1use std::collections::BTreeMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use chrono::Utc;
6use forge_core::{
7 AuthContext, CircuitBreakerClient, ForgeError, FunctionInfo, FunctionKind, JobDispatch,
8 MutationContext, OutboxBuffer, PendingJob, PendingWorkflow, QueryContext, RequestMetadata,
9 Result, WorkflowDispatch,
10 job::JobStatus,
11 rate_limit::{RateLimitConfig, RateLimitKey},
12 workflow::WorkflowStatus,
13};
14use serde_json::Value;
15
16use super::cache::QueryCache;
17use super::registry::{BoxedMutationFn, FunctionEntry, FunctionRegistry};
18use crate::db::Database;
19use crate::rate_limit::HybridRateLimiter;
20
21pub enum RouteResult {
23 Query(Value),
25 Mutation(Value),
27 Job(Value),
29 Workflow(Value),
31}
32
33pub struct FunctionRouter {
35 registry: Arc<FunctionRegistry>,
36 db: Database,
37 http_client: CircuitBreakerClient,
38 job_dispatcher: Option<Arc<dyn JobDispatch>>,
39 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
40 rate_limiter: HybridRateLimiter,
41 query_cache: QueryCache,
42}
43
44impl FunctionRouter {
45 pub fn new(registry: Arc<FunctionRegistry>, db: Database) -> Self {
47 let rate_limiter = HybridRateLimiter::new(db.primary().clone());
48 Self {
49 registry,
50 db,
51 http_client: CircuitBreakerClient::with_defaults(reqwest::Client::new()),
52 job_dispatcher: None,
53 workflow_dispatcher: None,
54 rate_limiter,
55 query_cache: QueryCache::new(),
56 }
57 }
58
59 pub fn with_http_client(
61 registry: Arc<FunctionRegistry>,
62 db: Database,
63 http_client: CircuitBreakerClient,
64 ) -> Self {
65 let rate_limiter = HybridRateLimiter::new(db.primary().clone());
66 Self {
67 registry,
68 db,
69 http_client,
70 job_dispatcher: None,
71 workflow_dispatcher: None,
72 rate_limiter,
73 query_cache: QueryCache::new(),
74 }
75 }
76
77 pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
79 self.job_dispatcher = Some(dispatcher);
80 self
81 }
82
83 pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
85 self.workflow_dispatcher = Some(dispatcher);
86 self
87 }
88
89 pub async fn route(
90 &self,
91 function_name: &str,
92 args: Value,
93 auth: AuthContext,
94 request: RequestMetadata,
95 ) -> Result<RouteResult> {
96 if let Some(entry) = self.registry.get(function_name) {
97 self.check_auth(entry.info(), &auth)?;
98 self.check_rate_limit(entry.info(), function_name, &auth, &request)
99 .await?;
100 Self::check_identity_args(function_name, &args, &auth, !entry.info().is_public)?;
101
102 return match entry {
103 FunctionEntry::Query { handler, info, .. } => {
104 let pool = if info.consistent {
105 self.db.primary().clone()
106 } else {
107 self.db.read_pool().clone()
108 };
109
110 let auth_scope = Self::auth_cache_scope(&auth);
111 if let Some(ttl) = info.cache_ttl {
112 if let Some(cached) =
113 self.query_cache
114 .get(function_name, &args, auth_scope.as_deref())
115 {
116 return Ok(RouteResult::Query(Value::clone(&cached)));
117 }
118
119 let ctx = QueryContext::new(pool, auth, request);
120 let result = handler(&ctx, args.clone()).await?;
121
122 self.query_cache.set(
123 function_name,
124 &args,
125 auth_scope.as_deref(),
126 result.clone(),
127 Duration::from_secs(ttl),
128 );
129
130 Ok(RouteResult::Query(result))
131 } else {
132 let ctx = QueryContext::new(pool, auth, request);
133 let result = handler(&ctx, args).await?;
134 Ok(RouteResult::Query(result))
135 }
136 }
137 FunctionEntry::Mutation { handler, info } => {
138 if info.transactional {
139 self.execute_transactional(handler, args, auth, request)
140 .await
141 } else {
142 let ctx = MutationContext::with_dispatch(
144 self.db.primary().clone(),
145 auth,
146 request,
147 self.http_client.clone(),
148 self.job_dispatcher.clone(),
149 self.workflow_dispatcher.clone(),
150 );
151 let result = handler(&ctx, args).await?;
152 Ok(RouteResult::Mutation(result))
153 }
154 }
155 };
156 }
157
158 if let Some(ref job_dispatcher) = self.job_dispatcher
159 && let Some(job_info) = job_dispatcher.get_info(function_name)
160 {
161 self.check_job_auth(&job_info, &auth)?;
162 Self::check_identity_args(function_name, &args, &auth, !job_info.is_public)?;
163 match job_dispatcher
164 .dispatch_by_name(function_name, args.clone(), auth.principal_id())
165 .await
166 {
167 Ok(job_id) => {
168 return Ok(RouteResult::Job(serde_json::json!({ "job_id": job_id })));
169 }
170 Err(ForgeError::NotFound(_)) => {}
171 Err(e) => return Err(e),
172 }
173 }
174
175 if let Some(ref workflow_dispatcher) = self.workflow_dispatcher
176 && let Some(workflow_info) = workflow_dispatcher.get_info(function_name)
177 {
178 self.check_workflow_auth(&workflow_info, &auth)?;
179 Self::check_identity_args(function_name, &args, &auth, !workflow_info.is_public)?;
180 match workflow_dispatcher
181 .start_by_name(function_name, args.clone(), auth.principal_id())
182 .await
183 {
184 Ok(workflow_id) => {
185 return Ok(RouteResult::Workflow(
186 serde_json::json!({ "workflow_id": workflow_id }),
187 ));
188 }
189 Err(ForgeError::NotFound(_)) => {}
190 Err(e) => return Err(e),
191 }
192 }
193
194 Err(ForgeError::NotFound(format!(
195 "Function '{}' not found",
196 function_name
197 )))
198 }
199
200 fn check_auth(&self, info: &FunctionInfo, auth: &AuthContext) -> Result<()> {
201 if info.is_public {
202 return Ok(());
203 }
204
205 if !auth.is_authenticated() {
206 return Err(ForgeError::Unauthorized("Authentication required".into()));
207 }
208
209 if let Some(role) = info.required_role
210 && !auth.has_role(role)
211 {
212 return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
213 }
214
215 Ok(())
216 }
217
218 fn check_job_auth(&self, info: &forge_core::job::JobInfo, auth: &AuthContext) -> Result<()> {
219 if info.is_public {
220 return Ok(());
221 }
222
223 if !auth.is_authenticated() {
224 return Err(ForgeError::Unauthorized("Authentication required".into()));
225 }
226
227 if let Some(role) = info.required_role
228 && !auth.has_role(role)
229 {
230 return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
231 }
232
233 Ok(())
234 }
235
236 fn check_workflow_auth(
237 &self,
238 info: &forge_core::workflow::WorkflowInfo,
239 auth: &AuthContext,
240 ) -> Result<()> {
241 if info.is_public {
242 return Ok(());
243 }
244
245 if !auth.is_authenticated() {
246 return Err(ForgeError::Unauthorized("Authentication required".into()));
247 }
248
249 if let Some(role) = info.required_role
250 && !auth.has_role(role)
251 {
252 return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
253 }
254
255 Ok(())
256 }
257
258 async fn check_rate_limit(
260 &self,
261 info: &FunctionInfo,
262 function_name: &str,
263 auth: &AuthContext,
264 request: &RequestMetadata,
265 ) -> Result<()> {
266 let (requests, per_secs) = match (info.rate_limit_requests, info.rate_limit_per_secs) {
268 (Some(r), Some(p)) => (r, p),
269 _ => return Ok(()),
270 };
271
272 let key_str = info.rate_limit_key.unwrap_or("user");
274 let key_type: RateLimitKey = match key_str.parse() {
275 Ok(k) => k,
276 Err(_) => {
277 tracing::warn!(
278 function = %function_name,
279 key = %key_str,
280 "Invalid rate limit key, falling back to 'user'"
281 );
282 RateLimitKey::default()
283 }
284 };
285
286 let config =
287 RateLimitConfig::new(requests, Duration::from_secs(per_secs)).with_key(key_type);
288
289 let bucket_key = self
291 .rate_limiter
292 .build_key(key_type, function_name, auth, request);
293
294 self.rate_limiter.enforce(&bucket_key, &config).await?;
296
297 Ok(())
298 }
299
300 fn auth_cache_scope(auth: &AuthContext) -> Option<String> {
301 if !auth.is_authenticated() {
302 return Some("anon".to_string());
303 }
304
305 let mut roles = auth.roles().to_vec();
307 roles.sort();
308 roles.dedup();
309
310 let mut claims = BTreeMap::new();
311 for (k, v) in auth.claims() {
312 claims.insert(k.clone(), v.clone());
313 }
314
315 use std::collections::hash_map::DefaultHasher;
316 use std::hash::{Hash, Hasher};
317
318 let mut hasher = DefaultHasher::new();
319 roles.hash(&mut hasher);
320 serde_json::to_string(&claims)
321 .unwrap_or_default()
322 .hash(&mut hasher);
323
324 let principal = auth
325 .principal_id()
326 .unwrap_or_else(|| "authenticated".to_string());
327
328 Some(format!(
329 "subject:{principal}:scope:{:016x}",
330 hasher.finish()
331 ))
332 }
333
334 fn check_identity_args(
335 function_name: &str,
336 args: &Value,
337 auth: &AuthContext,
338 enforce_scope: bool,
339 ) -> Result<()> {
340 if auth.is_admin() {
341 return Ok(());
342 }
343
344 let Some(obj) = args.as_object() else {
345 if enforce_scope && auth.is_authenticated() {
346 return Err(ForgeError::Forbidden(format!(
347 "Function '{function_name}' must include identity or tenant scope arguments"
348 )));
349 }
350 return Ok(());
351 };
352
353 let mut principal_values: Vec<String> = Vec::new();
354 if let Some(user_id) = auth.user_id().map(|id| id.to_string()) {
355 principal_values.push(user_id);
356 }
357 if let Some(subject) = auth.principal_id()
358 && !principal_values.iter().any(|v| v == &subject)
359 {
360 principal_values.push(subject);
361 }
362
363 let mut has_scope_key = false;
364
365 for key in [
366 "user_id",
367 "userId",
368 "owner_id",
369 "ownerId",
370 "owner_subject",
371 "ownerSubject",
372 "subject",
373 "sub",
374 "principal_id",
375 "principalId",
376 ] {
377 let Some(value) = obj.get(key) else {
378 continue;
379 };
380 has_scope_key = true;
381
382 if !auth.is_authenticated() {
383 return Err(ForgeError::Unauthorized(format!(
384 "Function '{function_name}' requires authentication for identity-scoped argument '{key}'"
385 )));
386 }
387
388 let Value::String(actual) = value else {
389 return Err(ForgeError::InvalidArgument(format!(
390 "Function '{function_name}' argument '{key}' must be a non-empty string"
391 )));
392 };
393
394 if actual.trim().is_empty() || !principal_values.iter().any(|v| v == actual) {
395 return Err(ForgeError::Forbidden(format!(
396 "Function '{function_name}' argument '{key}' does not match authenticated principal"
397 )));
398 }
399 }
400
401 for key in ["tenant_id", "tenantId"] {
402 let Some(value) = obj.get(key) else {
403 continue;
404 };
405 has_scope_key = true;
406
407 if !auth.is_authenticated() {
408 return Err(ForgeError::Unauthorized(format!(
409 "Function '{function_name}' requires authentication for tenant-scoped argument '{key}'"
410 )));
411 }
412
413 let expected = auth
414 .claim("tenant_id")
415 .and_then(|v| v.as_str())
416 .ok_or_else(|| {
417 ForgeError::Forbidden(format!(
418 "Function '{function_name}' argument '{key}' is not allowed for this principal"
419 ))
420 })?;
421
422 let Value::String(actual) = value else {
423 return Err(ForgeError::InvalidArgument(format!(
424 "Function '{function_name}' argument '{key}' must be a non-empty string"
425 )));
426 };
427
428 if actual.trim().is_empty() || actual != expected {
429 return Err(ForgeError::Forbidden(format!(
430 "Function '{function_name}' argument '{key}' does not match authenticated tenant"
431 )));
432 }
433 }
434
435 if enforce_scope && auth.is_authenticated() && !has_scope_key {
436 return Err(ForgeError::Forbidden(format!(
437 "Function '{function_name}' must include identity or tenant scope arguments"
438 )));
439 }
440
441 Ok(())
442 }
443
444 pub fn get_function_kind(&self, function_name: &str) -> Option<FunctionKind> {
446 self.registry.get(function_name).map(|e| e.kind())
447 }
448
449 pub fn has_function(&self, function_name: &str) -> bool {
451 self.registry.get(function_name).is_some()
452 }
453
454 async fn execute_transactional(
455 &self,
456 handler: &BoxedMutationFn,
457 args: Value,
458 auth: AuthContext,
459 request: RequestMetadata,
460 ) -> Result<RouteResult> {
461 let primary = self.db.primary();
463 let tx = primary
464 .begin()
465 .await
466 .map_err(|e| ForgeError::Database(e.to_string()))?;
467
468 let job_dispatcher = self.job_dispatcher.clone();
469 let job_lookup: forge_core::JobInfoLookup =
470 Arc::new(move |name: &str| job_dispatcher.as_ref().and_then(|d| d.get_info(name)));
471
472 let (ctx, tx_handle, outbox) = MutationContext::with_transaction(
473 primary.clone(),
474 tx,
475 auth,
476 request,
477 self.http_client.clone(),
478 job_lookup,
479 );
480
481 match handler(&ctx, args).await {
482 Ok(value) => {
483 let buffer = {
484 let guard = outbox.lock().expect("outbox mutex poisoned");
485 OutboxBuffer {
486 jobs: guard.jobs.clone(),
487 workflows: guard.workflows.clone(),
488 }
489 };
490
491 let mut tx = Arc::try_unwrap(tx_handle)
492 .map_err(|_| ForgeError::Internal("Transaction still in use".into()))?
493 .into_inner();
494
495 for job in &buffer.jobs {
496 Self::insert_job(&mut tx, job).await?;
497 }
498
499 for workflow in &buffer.workflows {
500 let version = self
501 .workflow_dispatcher
502 .as_ref()
503 .and_then(|d| d.get_info(&workflow.workflow_name))
504 .map(|info| info.version)
505 .ok_or_else(|| {
506 ForgeError::NotFound(format!(
507 "Workflow '{}' not found",
508 workflow.workflow_name
509 ))
510 })?;
511 Self::insert_workflow(&mut tx, workflow, version).await?;
512 }
513
514 tx.commit()
515 .await
516 .map_err(|e| ForgeError::Database(e.to_string()))?;
517
518 Ok(RouteResult::Mutation(value))
519 }
520 Err(e) => Err(e),
521 }
522 }
523
524 async fn insert_job(
525 tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
526 job: &PendingJob,
527 ) -> Result<()> {
528 let now = Utc::now();
529 sqlx::query(
530 r#"
531 INSERT INTO forge_jobs (
532 id, job_type, input, job_context, status, priority, attempts, max_attempts,
533 worker_capability, owner_subject, scheduled_at, created_at
534 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
535 "#,
536 )
537 .bind(job.id)
538 .bind(&job.job_type)
539 .bind(&job.args)
540 .bind(&job.context)
541 .bind(JobStatus::Pending.as_str())
542 .bind(job.priority)
543 .bind(0i32)
544 .bind(job.max_attempts)
545 .bind(&job.worker_capability)
546 .bind(&job.owner_subject)
547 .bind(now)
548 .bind(now)
549 .execute(&mut **tx)
550 .await
551 .map_err(|e| ForgeError::Database(e.to_string()))?;
552
553 Ok(())
554 }
555
556 async fn insert_workflow(
557 tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
558 workflow: &PendingWorkflow,
559 version: u32,
560 ) -> Result<()> {
561 let now = Utc::now();
562 sqlx::query(
563 r#"
564 INSERT INTO forge_workflow_runs (
565 id, workflow_name, version, owner_subject, input, status, current_step,
566 step_results, started_at, trace_id
567 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
568 "#,
569 )
570 .bind(workflow.id)
571 .bind(&workflow.workflow_name)
572 .bind(version as i32)
573 .bind(&workflow.owner_subject)
574 .bind(&workflow.input)
575 .bind(WorkflowStatus::Created.as_str())
576 .bind(Option::<String>::None)
577 .bind(serde_json::json!({}))
578 .bind(now)
579 .bind(workflow.id.to_string())
580 .execute(&mut **tx)
581 .await
582 .map_err(|e| ForgeError::Database(e.to_string()))?;
583
584 Ok(())
585 }
586}
587
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds an authenticated "user" context whose `sub` claim is the user id, plus
    /// any extra claims.
    fn user_auth(user_id: uuid::Uuid, extra: &[(&str, &str)]) -> AuthContext {
        let mut claims = HashMap::from([(
            "sub".to_string(),
            serde_json::Value::String(user_id.to_string()),
        )]);
        for (k, v) in extra {
            claims.insert(k.to_string(), serde_json::Value::String(v.to_string()));
        }
        AuthContext::authenticated(user_id, vec!["user".to_string()], claims)
    }

    #[test]
    fn test_check_auth_public() {
        let info = FunctionInfo {
            name: "test",
            description: None,
            kind: FunctionKind::Query,
            required_role: None,
            is_public: true,
            cache_ttl: None,
            timeout: None,
            rate_limit_requests: None,
            rate_limit_per_secs: None,
            rate_limit_key: None,
            log_level: None,
            table_dependencies: &[],
            selected_columns: &[],
            transactional: false,
            consistent: false,
        };

        let auth = AuthContext::unauthenticated();

        // A public function called anonymously with no scope arguments must be accepted.
        let result = FunctionRouter::check_identity_args(
            info.name,
            &serde_json::json!({}),
            &auth,
            !info.is_public,
        );
        assert!(result.is_ok());
    }

    #[test]
    fn test_identity_args_reject_cross_user_value() {
        let auth = user_auth(uuid::Uuid::new_v4(), &[]);
        let args = serde_json::json!({
            "user_id": uuid::Uuid::new_v4().to_string()
        });

        let result = FunctionRouter::check_identity_args("list_orders", &args, &auth, true);
        assert!(matches!(result, Err(ForgeError::Forbidden(_))));
    }

    #[test]
    fn test_identity_args_allow_matching_subject() {
        let sub = "user_123";
        let auth = AuthContext::authenticated_without_uuid(
            vec!["user".to_string()],
            HashMap::from([(
                "sub".to_string(),
                serde_json::Value::String(sub.to_string()),
            )]),
        );
        let args = serde_json::json!({
            "subject": sub
        });

        let result = FunctionRouter::check_identity_args("list_orders", &args, &auth, true);
        assert!(result.is_ok());
    }

    #[test]
    fn test_identity_args_reject_non_string_identity_value() {
        let auth = user_auth(uuid::Uuid::new_v4(), &[]);
        let args = serde_json::json!({ "user_id": 42 });

        let result = FunctionRouter::check_identity_args("list_orders", &args, &auth, true);
        assert!(matches!(result, Err(ForgeError::InvalidArgument(_))));
    }

    #[test]
    fn test_identity_args_require_auth_for_identity_keys() {
        let auth = AuthContext::unauthenticated();
        let args = serde_json::json!({
            "user_id": uuid::Uuid::new_v4().to_string()
        });

        let result = FunctionRouter::check_identity_args("list_orders", &args, &auth, true);
        assert!(matches!(result, Err(ForgeError::Unauthorized(_))));
    }

    #[test]
    fn test_identity_args_require_scope_for_non_public_calls() {
        let auth = user_auth(uuid::Uuid::new_v4(), &[]);

        let result =
            FunctionRouter::check_identity_args("list_orders", &serde_json::json!({}), &auth, true);
        assert!(matches!(result, Err(ForgeError::Forbidden(_))));
    }

    #[test]
    fn test_identity_args_reject_cross_tenant_value() {
        let auth = user_auth(uuid::Uuid::new_v4(), &[("tenant_id", "tenant-a")]);
        let args = serde_json::json!({ "tenant_id": "tenant-b" });

        let result = FunctionRouter::check_identity_args("list_orders", &args, &auth, true);
        assert!(matches!(result, Err(ForgeError::Forbidden(_))));
    }

    #[test]
    fn test_identity_args_allow_matching_tenant() {
        let auth = user_auth(uuid::Uuid::new_v4(), &[("tenant_id", "tenant-a")]);
        let args = serde_json::json!({ "tenantId": "tenant-a" });

        let result = FunctionRouter::check_identity_args("list_orders", &args, &auth, true);
        assert!(result.is_ok());
    }

    #[test]
    fn test_auth_cache_scope_anonymous() {
        let scope = FunctionRouter::auth_cache_scope(&AuthContext::unauthenticated());
        assert_eq!(scope, Some("anon".to_string()));
    }

    #[test]
    fn test_auth_cache_scope_is_deterministic() {
        let user_id = uuid::Uuid::new_v4();
        let auth_a = user_auth(user_id, &[("tenant_id", "tenant-a")]);
        let auth_b = user_auth(user_id, &[("tenant_id", "tenant-a")]);

        assert_eq!(
            FunctionRouter::auth_cache_scope(&auth_a),
            FunctionRouter::auth_cache_scope(&auth_b)
        );
    }

    #[test]
    fn test_auth_cache_scope_changes_with_claims() {
        let user_id = uuid::Uuid::new_v4();
        let auth_a = user_auth(user_id, &[("tenant_id", "tenant-a")]);
        let auth_b = user_auth(user_id, &[("tenant_id", "tenant-b")]);

        let scope_a = FunctionRouter::auth_cache_scope(&auth_a);
        let scope_b = FunctionRouter::auth_cache_scope(&auth_b);
        assert_ne!(scope_a, scope_b);
    }
}