1use std::collections::BTreeMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use chrono::Utc;
6use forge_core::{
7 AuthContext, CircuitBreakerClient, ForgeError, FunctionInfo, FunctionKind, JobDispatch,
8 MutationContext, OutboxBuffer, PendingJob, PendingWorkflow, QueryContext, RequestMetadata,
9 Result, WorkflowDispatch,
10 job::JobStatus,
11 rate_limit::{RateLimitConfig, RateLimitKey},
12 workflow::WorkflowStatus,
13};
14use serde_json::Value;
15
16use super::cache::QueryCache;
17use super::registry::{BoxedMutationFn, FunctionEntry, FunctionRegistry};
18use crate::db::Database;
19use crate::rate_limit::RateLimiter;
20
21pub enum RouteResult {
23 Query(Value),
25 Mutation(Value),
27 Job(Value),
29 Workflow(Value),
31}
32
/// Routes calls by function name to registered queries and mutations, or to
/// jobs and workflows through the optional dispatchers.
pub struct FunctionRouter {
    /// Registered query and mutation handlers, looked up by function name.
    registry: Arc<FunctionRegistry>,
    /// Database handle: queries run on `read_pool()`; mutations, transactions
    /// and the rate limiter use `primary()`.
    db: Database,
    /// Circuit-breaker HTTP client handed to every mutation context.
    http_client: CircuitBreakerClient,
    /// Optional job dispatcher; when `None`, job names are never resolved.
    job_dispatcher: Option<Arc<dyn JobDispatch>>,
    /// Optional workflow dispatcher; when `None`, workflow names are never resolved.
    workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
    /// Per-function rate limiter, constructed over the primary pool.
    rate_limiter: RateLimiter,
    /// Result cache for queries that declare a `cache_ttl`.
    query_cache: QueryCache,
}
43
44impl FunctionRouter {
45 pub fn new(registry: Arc<FunctionRegistry>, db: Database) -> Self {
47 let rate_limiter = RateLimiter::new(db.primary().clone());
48 Self {
49 registry,
50 db,
51 http_client: CircuitBreakerClient::with_defaults(reqwest::Client::new()),
52 job_dispatcher: None,
53 workflow_dispatcher: None,
54 rate_limiter,
55 query_cache: QueryCache::new(),
56 }
57 }
58
59 pub fn with_http_client(
61 registry: Arc<FunctionRegistry>,
62 db: Database,
63 http_client: CircuitBreakerClient,
64 ) -> Self {
65 let rate_limiter = RateLimiter::new(db.primary().clone());
66 Self {
67 registry,
68 db,
69 http_client,
70 job_dispatcher: None,
71 workflow_dispatcher: None,
72 rate_limiter,
73 query_cache: QueryCache::new(),
74 }
75 }
76
77 pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
79 self.job_dispatcher = Some(dispatcher);
80 self
81 }
82
83 pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
85 self.workflow_dispatcher = Some(dispatcher);
86 self
87 }
88
89 pub async fn route(
90 &self,
91 function_name: &str,
92 args: Value,
93 auth: AuthContext,
94 request: RequestMetadata,
95 ) -> Result<RouteResult> {
96 if let Some(entry) = self.registry.get(function_name) {
97 self.check_auth(entry.info(), &auth)?;
98 self.check_rate_limit(entry.info(), function_name, &auth, &request)
99 .await?;
100 Self::check_identity_args(function_name, &args, &auth, !entry.info().is_public)?;
101
102 return match entry {
103 FunctionEntry::Query { handler, info, .. } => {
104 let auth_scope = Self::auth_cache_scope(&auth);
105 if let Some(ttl) = info.cache_ttl {
106 if let Some(cached) =
107 self.query_cache
108 .get(function_name, &args, auth_scope.as_deref())
109 {
110 return Ok(RouteResult::Query(Value::clone(&cached)));
111 }
112
113 let ctx = QueryContext::new(self.db.read_pool().clone(), auth, request);
115 let result = handler(&ctx, args.clone()).await?;
116
117 self.query_cache.set(
118 function_name,
119 &args,
120 auth_scope.as_deref(),
121 result.clone(),
122 Duration::from_secs(ttl),
123 );
124
125 Ok(RouteResult::Query(result))
126 } else {
127 let ctx = QueryContext::new(self.db.read_pool().clone(), auth, request);
129 let result = handler(&ctx, args).await?;
130 Ok(RouteResult::Query(result))
131 }
132 }
133 FunctionEntry::Mutation { handler, info } => {
134 if info.transactional {
135 self.execute_transactional(handler, args, auth, request)
136 .await
137 } else {
138 let ctx = MutationContext::with_dispatch(
140 self.db.primary().clone(),
141 auth,
142 request,
143 self.http_client.clone(),
144 self.job_dispatcher.clone(),
145 self.workflow_dispatcher.clone(),
146 );
147 let result = handler(&ctx, args).await?;
148 Ok(RouteResult::Mutation(result))
149 }
150 }
151 };
152 }
153
154 if let Some(ref job_dispatcher) = self.job_dispatcher
155 && let Some(job_info) = job_dispatcher.get_info(function_name)
156 {
157 self.check_job_auth(&job_info, &auth)?;
158 Self::check_identity_args(function_name, &args, &auth, !job_info.is_public)?;
159 match job_dispatcher
160 .dispatch_by_name(function_name, args.clone(), auth.principal_id())
161 .await
162 {
163 Ok(job_id) => {
164 return Ok(RouteResult::Job(serde_json::json!({ "job_id": job_id })));
165 }
166 Err(ForgeError::NotFound(_)) => {}
167 Err(e) => return Err(e),
168 }
169 }
170
171 if let Some(ref workflow_dispatcher) = self.workflow_dispatcher
172 && let Some(workflow_info) = workflow_dispatcher.get_info(function_name)
173 {
174 self.check_workflow_auth(&workflow_info, &auth)?;
175 Self::check_identity_args(function_name, &args, &auth, !workflow_info.is_public)?;
176 match workflow_dispatcher
177 .start_by_name(function_name, args.clone(), auth.principal_id())
178 .await
179 {
180 Ok(workflow_id) => {
181 return Ok(RouteResult::Workflow(
182 serde_json::json!({ "workflow_id": workflow_id }),
183 ));
184 }
185 Err(ForgeError::NotFound(_)) => {}
186 Err(e) => return Err(e),
187 }
188 }
189
190 Err(ForgeError::NotFound(format!(
191 "Function '{}' not found",
192 function_name
193 )))
194 }
195
196 fn check_auth(&self, info: &FunctionInfo, auth: &AuthContext) -> Result<()> {
197 if info.is_public {
198 return Ok(());
199 }
200
201 if !auth.is_authenticated() {
202 return Err(ForgeError::Unauthorized("Authentication required".into()));
203 }
204
205 if let Some(role) = info.required_role
206 && !auth.has_role(role)
207 {
208 return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
209 }
210
211 Ok(())
212 }
213
214 fn check_job_auth(&self, info: &forge_core::job::JobInfo, auth: &AuthContext) -> Result<()> {
215 if info.is_public {
216 return Ok(());
217 }
218
219 if !auth.is_authenticated() {
220 return Err(ForgeError::Unauthorized("Authentication required".into()));
221 }
222
223 if let Some(role) = info.required_role
224 && !auth.has_role(role)
225 {
226 return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
227 }
228
229 Ok(())
230 }
231
232 fn check_workflow_auth(
233 &self,
234 info: &forge_core::workflow::WorkflowInfo,
235 auth: &AuthContext,
236 ) -> Result<()> {
237 if info.is_public {
238 return Ok(());
239 }
240
241 if !auth.is_authenticated() {
242 return Err(ForgeError::Unauthorized("Authentication required".into()));
243 }
244
245 if let Some(role) = info.required_role
246 && !auth.has_role(role)
247 {
248 return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
249 }
250
251 Ok(())
252 }
253
254 async fn check_rate_limit(
256 &self,
257 info: &FunctionInfo,
258 function_name: &str,
259 auth: &AuthContext,
260 request: &RequestMetadata,
261 ) -> Result<()> {
262 let (requests, per_secs) = match (info.rate_limit_requests, info.rate_limit_per_secs) {
264 (Some(r), Some(p)) => (r, p),
265 _ => return Ok(()),
266 };
267
268 let key_str = info.rate_limit_key.unwrap_or("user");
270 let key_type: RateLimitKey = match key_str.parse() {
271 Ok(k) => k,
272 Err(_) => {
273 tracing::warn!(
274 function = %function_name,
275 key = %key_str,
276 "Invalid rate limit key, falling back to 'user'"
277 );
278 RateLimitKey::default()
279 }
280 };
281
282 let config =
283 RateLimitConfig::new(requests, Duration::from_secs(per_secs)).with_key(key_type);
284
285 let bucket_key = self
287 .rate_limiter
288 .build_key(key_type, function_name, auth, request);
289
290 self.rate_limiter.enforce(&bucket_key, &config).await?;
292
293 Ok(())
294 }
295
296 fn auth_cache_scope(auth: &AuthContext) -> Option<String> {
297 if !auth.is_authenticated() {
298 return Some("anon".to_string());
299 }
300
301 let mut roles = auth.roles().to_vec();
303 roles.sort();
304 roles.dedup();
305
306 let mut claims = BTreeMap::new();
307 for (k, v) in auth.claims() {
308 claims.insert(k.clone(), v.clone());
309 }
310
311 use std::collections::hash_map::DefaultHasher;
312 use std::hash::{Hash, Hasher};
313
314 let mut hasher = DefaultHasher::new();
315 roles.hash(&mut hasher);
316 serde_json::to_string(&claims)
317 .unwrap_or_default()
318 .hash(&mut hasher);
319
320 let principal = auth
321 .principal_id()
322 .unwrap_or_else(|| "authenticated".to_string());
323
324 Some(format!(
325 "subject:{principal}:scope:{:016x}",
326 hasher.finish()
327 ))
328 }
329
330 fn check_identity_args(
331 function_name: &str,
332 args: &Value,
333 auth: &AuthContext,
334 enforce_scope: bool,
335 ) -> Result<()> {
336 if auth.is_admin() {
337 return Ok(());
338 }
339
340 let Some(obj) = args.as_object() else {
341 if enforce_scope && auth.is_authenticated() {
342 return Err(ForgeError::Forbidden(format!(
343 "Function '{function_name}' must include identity or tenant scope arguments"
344 )));
345 }
346 return Ok(());
347 };
348
349 let mut principal_values: Vec<String> = Vec::new();
350 if let Some(user_id) = auth.user_id().map(|id| id.to_string()) {
351 principal_values.push(user_id);
352 }
353 if let Some(subject) = auth.principal_id()
354 && !principal_values.iter().any(|v| v == &subject)
355 {
356 principal_values.push(subject);
357 }
358
359 let mut has_scope_key = false;
360
361 for key in [
362 "user_id",
363 "userId",
364 "owner_id",
365 "ownerId",
366 "owner_subject",
367 "ownerSubject",
368 "subject",
369 "sub",
370 "principal_id",
371 "principalId",
372 ] {
373 let Some(value) = obj.get(key) else {
374 continue;
375 };
376 has_scope_key = true;
377
378 if !auth.is_authenticated() {
379 return Err(ForgeError::Unauthorized(format!(
380 "Function '{function_name}' requires authentication for identity-scoped argument '{key}'"
381 )));
382 }
383
384 let Value::String(actual) = value else {
385 return Err(ForgeError::InvalidArgument(format!(
386 "Function '{function_name}' argument '{key}' must be a non-empty string"
387 )));
388 };
389
390 if actual.trim().is_empty() || !principal_values.iter().any(|v| v == actual) {
391 return Err(ForgeError::Forbidden(format!(
392 "Function '{function_name}' argument '{key}' does not match authenticated principal"
393 )));
394 }
395 }
396
397 for key in ["tenant_id", "tenantId"] {
398 let Some(value) = obj.get(key) else {
399 continue;
400 };
401 has_scope_key = true;
402
403 if !auth.is_authenticated() {
404 return Err(ForgeError::Unauthorized(format!(
405 "Function '{function_name}' requires authentication for tenant-scoped argument '{key}'"
406 )));
407 }
408
409 let expected = auth
410 .claim("tenant_id")
411 .and_then(|v| v.as_str())
412 .ok_or_else(|| {
413 ForgeError::Forbidden(format!(
414 "Function '{function_name}' argument '{key}' is not allowed for this principal"
415 ))
416 })?;
417
418 let Value::String(actual) = value else {
419 return Err(ForgeError::InvalidArgument(format!(
420 "Function '{function_name}' argument '{key}' must be a non-empty string"
421 )));
422 };
423
424 if actual.trim().is_empty() || actual != expected {
425 return Err(ForgeError::Forbidden(format!(
426 "Function '{function_name}' argument '{key}' does not match authenticated tenant"
427 )));
428 }
429 }
430
431 if enforce_scope && auth.is_authenticated() && !has_scope_key {
432 return Err(ForgeError::Forbidden(format!(
433 "Function '{function_name}' must include identity or tenant scope arguments"
434 )));
435 }
436
437 Ok(())
438 }
439
440 pub fn get_function_kind(&self, function_name: &str) -> Option<FunctionKind> {
442 self.registry.get(function_name).map(|e| e.kind())
443 }
444
445 pub fn has_function(&self, function_name: &str) -> bool {
447 self.registry.get(function_name).is_some()
448 }
449
450 async fn execute_transactional(
451 &self,
452 handler: &BoxedMutationFn,
453 args: Value,
454 auth: AuthContext,
455 request: RequestMetadata,
456 ) -> Result<RouteResult> {
457 let primary = self.db.primary();
459 let tx = primary
460 .begin()
461 .await
462 .map_err(|e| ForgeError::Database(e.to_string()))?;
463
464 let job_dispatcher = self.job_dispatcher.clone();
465 let job_lookup: forge_core::JobInfoLookup =
466 Arc::new(move |name: &str| job_dispatcher.as_ref().and_then(|d| d.get_info(name)));
467
468 let (ctx, tx_handle, outbox) = MutationContext::with_transaction(
469 primary.clone(),
470 tx,
471 auth,
472 request,
473 self.http_client.clone(),
474 job_lookup,
475 );
476
477 match handler(&ctx, args).await {
478 Ok(value) => {
479 let buffer = {
480 let guard = outbox.lock().expect("outbox mutex poisoned");
481 OutboxBuffer {
482 jobs: guard.jobs.clone(),
483 workflows: guard.workflows.clone(),
484 }
485 };
486
487 let mut tx = Arc::try_unwrap(tx_handle)
488 .map_err(|_| ForgeError::Internal("Transaction still in use".into()))?
489 .into_inner();
490
491 for job in &buffer.jobs {
492 Self::insert_job(&mut tx, job).await?;
493 }
494
495 for workflow in &buffer.workflows {
496 let version = self
497 .workflow_dispatcher
498 .as_ref()
499 .and_then(|d| d.get_info(&workflow.workflow_name))
500 .map(|info| info.version)
501 .ok_or_else(|| {
502 ForgeError::NotFound(format!(
503 "Workflow '{}' not found",
504 workflow.workflow_name
505 ))
506 })?;
507 Self::insert_workflow(&mut tx, workflow, version).await?;
508 }
509
510 tx.commit()
511 .await
512 .map_err(|e| ForgeError::Database(e.to_string()))?;
513
514 Ok(RouteResult::Mutation(value))
515 }
516 Err(e) => Err(e),
517 }
518 }
519
520 async fn insert_job(
521 tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
522 job: &PendingJob,
523 ) -> Result<()> {
524 let now = Utc::now();
525 sqlx::query(
526 r#"
527 INSERT INTO forge_jobs (
528 id, job_type, input, job_context, status, priority, attempts, max_attempts,
529 worker_capability, owner_subject, scheduled_at, created_at
530 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
531 "#,
532 )
533 .bind(job.id)
534 .bind(&job.job_type)
535 .bind(&job.args)
536 .bind(&job.context)
537 .bind(JobStatus::Pending.as_str())
538 .bind(job.priority)
539 .bind(0i32)
540 .bind(job.max_attempts)
541 .bind(&job.worker_capability)
542 .bind(&job.owner_subject)
543 .bind(now)
544 .bind(now)
545 .execute(&mut **tx)
546 .await
547 .map_err(|e| ForgeError::Database(e.to_string()))?;
548
549 Ok(())
550 }
551
552 async fn insert_workflow(
553 tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
554 workflow: &PendingWorkflow,
555 version: u32,
556 ) -> Result<()> {
557 let now = Utc::now();
558 sqlx::query(
559 r#"
560 INSERT INTO forge_workflow_runs (
561 id, workflow_name, version, owner_subject, input, status, current_step,
562 step_results, started_at, trace_id
563 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
564 "#,
565 )
566 .bind(workflow.id)
567 .bind(&workflow.workflow_name)
568 .bind(version as i32)
569 .bind(&workflow.owner_subject)
570 .bind(&workflow.input)
571 .bind(WorkflowStatus::Created.as_str())
572 .bind(Option::<String>::None)
573 .bind(serde_json::json!({}))
574 .bind(now)
575 .bind(workflow.id.to_string())
576 .execute(&mut **tx)
577 .await
578 .map_err(|e| ForgeError::Database(e.to_string()))?;
579
580 Ok(())
581 }
582}
583
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds an authenticated context for `user_id` with the `user` role, a
    /// `sub` claim equal to the user id, and any `extra` string claims.
    fn user_auth(user_id: uuid::Uuid, extra: &[(&str, &str)]) -> AuthContext {
        let mut claims = HashMap::from([(
            "sub".to_string(),
            serde_json::Value::String(user_id.to_string()),
        )]);
        for &(key, value) in extra {
            claims.insert(
                key.to_string(),
                serde_json::Value::String(value.to_string()),
            );
        }
        AuthContext::authenticated(user_id, vec!["user".to_string()], claims)
    }

    #[test]
    fn test_public_function_allows_anonymous_unscoped_call() {
        let info = FunctionInfo {
            name: "test",
            description: None,
            kind: FunctionKind::Query,
            required_role: None,
            is_public: true,
            cache_ttl: None,
            timeout: None,
            rate_limit_requests: None,
            rate_limit_per_secs: None,
            rate_limit_key: None,
            log_level: None,
            table_dependencies: &[],
            transactional: false,
        };

        let auth = AuthContext::unauthenticated();

        // Public functions do not enforce scope, so an anonymous call without
        // identity or tenant arguments must be accepted.
        let result = FunctionRouter::check_identity_args(
            info.name,
            &serde_json::json!({}),
            &auth,
            !info.is_public,
        );
        assert!(result.is_ok());
    }

    #[test]
    fn test_identity_args_reject_cross_user_value() {
        let auth = user_auth(uuid::Uuid::new_v4(), &[]);
        let args = serde_json::json!({
            "user_id": uuid::Uuid::new_v4().to_string()
        });

        let result = FunctionRouter::check_identity_args("list_orders", &args, &auth, true);
        assert!(matches!(result, Err(ForgeError::Forbidden(_))));
    }

    #[test]
    fn test_identity_args_allow_matching_subject() {
        let sub = "user_123";
        let auth = AuthContext::authenticated_without_uuid(
            vec!["user".to_string()],
            HashMap::from([(
                "sub".to_string(),
                serde_json::Value::String(sub.to_string()),
            )]),
        );
        let args = serde_json::json!({
            "subject": sub
        });

        let result = FunctionRouter::check_identity_args("list_orders", &args, &auth, true);
        assert!(result.is_ok());
    }

    #[test]
    fn test_identity_args_require_auth_for_identity_keys() {
        let auth = AuthContext::unauthenticated();
        let args = serde_json::json!({
            "user_id": uuid::Uuid::new_v4().to_string()
        });

        let result = FunctionRouter::check_identity_args("list_orders", &args, &auth, true);
        assert!(matches!(result, Err(ForgeError::Unauthorized(_))));
    }

    #[test]
    fn test_identity_args_require_scope_for_non_public_calls() {
        let auth = user_auth(uuid::Uuid::new_v4(), &[]);

        let result =
            FunctionRouter::check_identity_args("list_orders", &serde_json::json!({}), &auth, true);
        assert!(matches!(result, Err(ForgeError::Forbidden(_))));
    }

    #[test]
    fn test_identity_args_reject_non_string_identity() {
        let auth = user_auth(uuid::Uuid::new_v4(), &[]);
        let args = serde_json::json!({ "user_id": 42 });

        let result = FunctionRouter::check_identity_args("list_orders", &args, &auth, true);
        assert!(matches!(result, Err(ForgeError::InvalidArgument(_))));
    }

    #[test]
    fn test_identity_args_non_object_args() {
        let auth = user_auth(uuid::Uuid::new_v4(), &[]);
        let args = serde_json::Value::Null;

        let enforced = FunctionRouter::check_identity_args("list_orders", &args, &auth, true);
        assert!(matches!(enforced, Err(ForgeError::Forbidden(_))));

        let relaxed = FunctionRouter::check_identity_args("list_orders", &args, &auth, false);
        assert!(relaxed.is_ok());
    }

    #[test]
    fn test_identity_args_reject_tenant_mismatch() {
        let auth = user_auth(uuid::Uuid::new_v4(), &[("tenant_id", "tenant-a")]);
        let args = serde_json::json!({ "tenant_id": "tenant-b" });

        let result = FunctionRouter::check_identity_args("list_orders", &args, &auth, true);
        assert!(matches!(result, Err(ForgeError::Forbidden(_))));
    }

    #[test]
    fn test_identity_args_allow_matching_tenant() {
        let auth = user_auth(uuid::Uuid::new_v4(), &[("tenant_id", "tenant-a")]);
        let args = serde_json::json!({ "tenantId": "tenant-a" });

        let result = FunctionRouter::check_identity_args("list_orders", &args, &auth, true);
        assert!(result.is_ok());
    }

    #[test]
    fn test_identity_args_reject_tenant_without_claim() {
        let auth = user_auth(uuid::Uuid::new_v4(), &[]);
        let args = serde_json::json!({ "tenant_id": "tenant-a" });

        let result = FunctionRouter::check_identity_args("list_orders", &args, &auth, true);
        assert!(matches!(result, Err(ForgeError::Forbidden(_))));
    }

    #[test]
    fn test_auth_cache_scope_anonymous() {
        let scope = FunctionRouter::auth_cache_scope(&AuthContext::unauthenticated());
        assert_eq!(scope, Some("anon".to_string()));
    }

    #[test]
    fn test_auth_cache_scope_changes_with_claims() {
        let user_id = uuid::Uuid::new_v4();
        let auth_a = user_auth(user_id, &[("tenant_id", "tenant-a")]);
        let auth_b = user_auth(user_id, &[("tenant_id", "tenant-b")]);

        let scope_a = FunctionRouter::auth_cache_scope(&auth_a);
        let scope_b = FunctionRouter::auth_cache_scope(&auth_b);
        assert_ne!(scope_a, scope_b);
    }
}