1use std::collections::BTreeMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use chrono::Utc;
6use forge_core::{
7 AuthContext, CircuitBreakerClient, ForgeError, FunctionInfo, FunctionKind, JobDispatch,
8 MutationContext, OutboxBuffer, PendingJob, PendingWorkflow, QueryContext, RequestMetadata,
9 Result, WorkflowDispatch,
10 job::JobStatus,
11 rate_limit::{RateLimitConfig, RateLimitKey},
12 workflow::WorkflowStatus,
13};
14use serde_json::Value;
15use tracing::Instrument;
16
17use super::cache::QueryCache;
18use super::registry::{BoxedMutationFn, FunctionEntry, FunctionRegistry};
19use crate::db::Database;
20use crate::rate_limit::HybridRateLimiter;
21
/// Outcome of [`FunctionRouter::route`], tagged by the kind of target that handled the call.
pub enum RouteResult {
    /// Value produced by a query handler, or a copy of a previously cached query result.
    Query(Value),
    /// Value produced by a mutation handler.
    Mutation(Value),
    /// JSON object of the form `{ "job_id": ... }` for the job that was dispatched.
    Job(Value),
    /// JSON object of the form `{ "workflow_id": ... }` for the workflow that was started.
    Workflow(Value),
}
33
/// Routes a call by function name to a registered query or mutation handler, or to the
/// optional job and workflow dispatchers.
pub struct FunctionRouter {
    /// Registered function entries, looked up by function name.
    registry: Arc<FunctionRegistry>,
    /// Database handle; provides the primary pool and the read pool used by handlers.
    db: Database,
    /// HTTP client cloned into every mutation context.
    http_client: CircuitBreakerClient,
    /// Job dispatcher; `None` until set through `with_job_dispatcher`.
    job_dispatcher: Option<Arc<dyn JobDispatch>>,
    /// Workflow dispatcher; `None` until set through `with_workflow_dispatcher`.
    workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
    /// Rate limiter, constructed from a clone of the primary pool.
    rate_limiter: HybridRateLimiter,
    /// Query result cache, addressed by function name, arguments and auth scope.
    query_cache: QueryCache,
}
44
45impl FunctionRouter {
46 pub fn new(registry: Arc<FunctionRegistry>, db: Database) -> Self {
48 let rate_limiter = HybridRateLimiter::new(db.primary().clone());
49 Self {
50 registry,
51 db,
52 http_client: CircuitBreakerClient::with_defaults(reqwest::Client::new()),
53 job_dispatcher: None,
54 workflow_dispatcher: None,
55 rate_limiter,
56 query_cache: QueryCache::new(),
57 }
58 }
59
60 pub fn with_http_client(
62 registry: Arc<FunctionRegistry>,
63 db: Database,
64 http_client: CircuitBreakerClient,
65 ) -> Self {
66 let rate_limiter = HybridRateLimiter::new(db.primary().clone());
67 Self {
68 registry,
69 db,
70 http_client,
71 job_dispatcher: None,
72 workflow_dispatcher: None,
73 rate_limiter,
74 query_cache: QueryCache::new(),
75 }
76 }
77
78 pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
80 self.job_dispatcher = Some(dispatcher);
81 self
82 }
83
84 pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
86 self.workflow_dispatcher = Some(dispatcher);
87 self
88 }
89
90 pub async fn route(
91 &self,
92 function_name: &str,
93 args: Value,
94 auth: AuthContext,
95 request: RequestMetadata,
96 ) -> Result<RouteResult> {
97 if let Some(entry) = self.registry.get(function_name) {
98 self.check_auth(entry.info(), &auth)?;
99 self.check_rate_limit(entry.info(), function_name, &auth, &request)
100 .await?;
101 Self::check_identity_args(function_name, &args, &auth, !entry.info().is_public)?;
102
103 return match entry {
104 FunctionEntry::Query { handler, info, .. } => {
105 let pool = if info.consistent {
106 self.db.primary().clone()
107 } else {
108 self.db.read_pool().clone()
109 };
110
111 let auth_scope = Self::auth_cache_scope(&auth);
112 if let Some(ttl) = info.cache_ttl {
113 if let Some(cached) =
114 self.query_cache
115 .get(function_name, &args, auth_scope.as_deref())
116 {
117 return Ok(RouteResult::Query(Value::clone(&cached)));
118 }
119
120 let ctx = QueryContext::new(pool, auth, request);
121 let result = handler(&ctx, args.clone()).await?;
122
123 self.query_cache.set(
124 function_name,
125 &args,
126 auth_scope.as_deref(),
127 result.clone(),
128 Duration::from_secs(ttl),
129 );
130
131 Ok(RouteResult::Query(result))
132 } else {
133 let ctx = QueryContext::new(pool, auth, request);
134 let result = handler(&ctx, args).await?;
135 Ok(RouteResult::Query(result))
136 }
137 }
138 FunctionEntry::Mutation { handler, info } => {
139 if info.transactional {
140 self.execute_transactional(handler, args, auth, request)
141 .await
142 } else {
143 let ctx = MutationContext::with_dispatch(
145 self.db.primary().clone(),
146 auth,
147 request,
148 self.http_client.clone(),
149 self.job_dispatcher.clone(),
150 self.workflow_dispatcher.clone(),
151 );
152 let result = handler(&ctx, args).await?;
153 Ok(RouteResult::Mutation(result))
154 }
155 }
156 };
157 }
158
159 if let Some(ref job_dispatcher) = self.job_dispatcher
160 && let Some(job_info) = job_dispatcher.get_info(function_name)
161 {
162 self.check_job_auth(&job_info, &auth)?;
163 Self::check_identity_args(function_name, &args, &auth, !job_info.is_public)?;
164 match job_dispatcher
165 .dispatch_by_name(function_name, args.clone(), auth.principal_id())
166 .await
167 {
168 Ok(job_id) => {
169 return Ok(RouteResult::Job(serde_json::json!({ "job_id": job_id })));
170 }
171 Err(ForgeError::NotFound(_)) => {}
172 Err(e) => return Err(e),
173 }
174 }
175
176 if let Some(ref workflow_dispatcher) = self.workflow_dispatcher
177 && let Some(workflow_info) = workflow_dispatcher.get_info(function_name)
178 {
179 self.check_workflow_auth(&workflow_info, &auth)?;
180 Self::check_identity_args(function_name, &args, &auth, !workflow_info.is_public)?;
181 match workflow_dispatcher
182 .start_by_name(function_name, args.clone(), auth.principal_id())
183 .await
184 {
185 Ok(workflow_id) => {
186 return Ok(RouteResult::Workflow(
187 serde_json::json!({ "workflow_id": workflow_id }),
188 ));
189 }
190 Err(ForgeError::NotFound(_)) => {}
191 Err(e) => return Err(e),
192 }
193 }
194
195 Err(ForgeError::NotFound(format!(
196 "Function '{}' not found",
197 function_name
198 )))
199 }
200
201 fn check_auth(&self, info: &FunctionInfo, auth: &AuthContext) -> Result<()> {
202 if info.is_public {
203 return Ok(());
204 }
205
206 if !auth.is_authenticated() {
207 return Err(ForgeError::Unauthorized("Authentication required".into()));
208 }
209
210 if let Some(role) = info.required_role
211 && !auth.has_role(role)
212 {
213 return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
214 }
215
216 Ok(())
217 }
218
219 fn check_job_auth(&self, info: &forge_core::job::JobInfo, auth: &AuthContext) -> Result<()> {
220 if info.is_public {
221 return Ok(());
222 }
223
224 if !auth.is_authenticated() {
225 return Err(ForgeError::Unauthorized("Authentication required".into()));
226 }
227
228 if let Some(role) = info.required_role
229 && !auth.has_role(role)
230 {
231 return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
232 }
233
234 Ok(())
235 }
236
237 fn check_workflow_auth(
238 &self,
239 info: &forge_core::workflow::WorkflowInfo,
240 auth: &AuthContext,
241 ) -> Result<()> {
242 if info.is_public {
243 return Ok(());
244 }
245
246 if !auth.is_authenticated() {
247 return Err(ForgeError::Unauthorized("Authentication required".into()));
248 }
249
250 if let Some(role) = info.required_role
251 && !auth.has_role(role)
252 {
253 return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
254 }
255
256 Ok(())
257 }
258
259 async fn check_rate_limit(
261 &self,
262 info: &FunctionInfo,
263 function_name: &str,
264 auth: &AuthContext,
265 request: &RequestMetadata,
266 ) -> Result<()> {
267 let (requests, per_secs) = match (info.rate_limit_requests, info.rate_limit_per_secs) {
269 (Some(r), Some(p)) => (r, p),
270 _ => return Ok(()),
271 };
272
273 let key_str = info.rate_limit_key.unwrap_or("user");
275 let key_type: RateLimitKey = match key_str.parse() {
276 Ok(k) => k,
277 Err(_) => {
278 tracing::warn!(
279 function = %function_name,
280 key = %key_str,
281 "Invalid rate limit key, falling back to 'user'"
282 );
283 RateLimitKey::default()
284 }
285 };
286
287 let config =
288 RateLimitConfig::new(requests, Duration::from_secs(per_secs)).with_key(key_type);
289
290 let bucket_key = self
292 .rate_limiter
293 .build_key(key_type, function_name, auth, request);
294
295 self.rate_limiter.enforce(&bucket_key, &config).await?;
297
298 Ok(())
299 }
300
301 fn auth_cache_scope(auth: &AuthContext) -> Option<String> {
302 if !auth.is_authenticated() {
303 return Some("anon".to_string());
304 }
305
306 let mut roles = auth.roles().to_vec();
308 roles.sort();
309 roles.dedup();
310
311 let mut claims = BTreeMap::new();
312 for (k, v) in auth.claims() {
313 claims.insert(k.clone(), v.clone());
314 }
315
316 use std::collections::hash_map::DefaultHasher;
317 use std::hash::{Hash, Hasher};
318
319 let mut hasher = DefaultHasher::new();
320 roles.hash(&mut hasher);
321 serde_json::to_string(&claims)
322 .unwrap_or_default()
323 .hash(&mut hasher);
324
325 let principal = auth
326 .principal_id()
327 .unwrap_or_else(|| "authenticated".to_string());
328
329 Some(format!(
330 "subject:{principal}:scope:{:016x}",
331 hasher.finish()
332 ))
333 }
334
335 fn check_identity_args(
336 function_name: &str,
337 args: &Value,
338 auth: &AuthContext,
339 enforce_scope: bool,
340 ) -> Result<()> {
341 if auth.is_admin() {
342 return Ok(());
343 }
344
345 let Some(obj) = args.as_object() else {
346 if enforce_scope && auth.is_authenticated() {
347 return Err(ForgeError::Forbidden(format!(
348 "Function '{function_name}' must include identity or tenant scope arguments"
349 )));
350 }
351 return Ok(());
352 };
353
354 let mut principal_values: Vec<String> = Vec::new();
355 if let Some(user_id) = auth.user_id().map(|id| id.to_string()) {
356 principal_values.push(user_id);
357 }
358 if let Some(subject) = auth.principal_id()
359 && !principal_values.iter().any(|v| v == &subject)
360 {
361 principal_values.push(subject);
362 }
363
364 let mut has_scope_key = false;
365
366 for key in [
367 "user_id",
368 "userId",
369 "owner_id",
370 "ownerId",
371 "owner_subject",
372 "ownerSubject",
373 "subject",
374 "sub",
375 "principal_id",
376 "principalId",
377 ] {
378 let Some(value) = obj.get(key) else {
379 continue;
380 };
381 has_scope_key = true;
382
383 if !auth.is_authenticated() {
384 return Err(ForgeError::Unauthorized(format!(
385 "Function '{function_name}' requires authentication for identity-scoped argument '{key}'"
386 )));
387 }
388
389 let Value::String(actual) = value else {
390 return Err(ForgeError::InvalidArgument(format!(
391 "Function '{function_name}' argument '{key}' must be a non-empty string"
392 )));
393 };
394
395 if actual.trim().is_empty() || !principal_values.iter().any(|v| v == actual) {
396 return Err(ForgeError::Forbidden(format!(
397 "Function '{function_name}' argument '{key}' does not match authenticated principal"
398 )));
399 }
400 }
401
402 for key in ["tenant_id", "tenantId"] {
403 let Some(value) = obj.get(key) else {
404 continue;
405 };
406 has_scope_key = true;
407
408 if !auth.is_authenticated() {
409 return Err(ForgeError::Unauthorized(format!(
410 "Function '{function_name}' requires authentication for tenant-scoped argument '{key}'"
411 )));
412 }
413
414 let expected = auth
415 .claim("tenant_id")
416 .and_then(|v| v.as_str())
417 .ok_or_else(|| {
418 ForgeError::Forbidden(format!(
419 "Function '{function_name}' argument '{key}' is not allowed for this principal"
420 ))
421 })?;
422
423 let Value::String(actual) = value else {
424 return Err(ForgeError::InvalidArgument(format!(
425 "Function '{function_name}' argument '{key}' must be a non-empty string"
426 )));
427 };
428
429 if actual.trim().is_empty() || actual != expected {
430 return Err(ForgeError::Forbidden(format!(
431 "Function '{function_name}' argument '{key}' does not match authenticated tenant"
432 )));
433 }
434 }
435
436 if enforce_scope && auth.is_authenticated() && !has_scope_key {
437 return Err(ForgeError::Forbidden(format!(
438 "Function '{function_name}' must include identity or tenant scope arguments"
439 )));
440 }
441
442 Ok(())
443 }
444
445 pub fn get_function_kind(&self, function_name: &str) -> Option<FunctionKind> {
447 self.registry.get(function_name).map(|e| e.kind())
448 }
449
450 pub fn has_function(&self, function_name: &str) -> bool {
452 self.registry.get(function_name).is_some()
453 }
454
/// Runs a mutation handler inside a transaction opened on the primary pool.
///
/// On handler success, the jobs and workflows collected in the outbox are inserted
/// through the same transaction, which is then committed, so they are persisted
/// together with the handler's own writes.
///
/// # Errors
///
/// * `ForgeError::Database` when beginning or committing the transaction fails.
/// * `ForgeError::Internal` when the transaction handle is still shared after the
///   handler returns.
/// * `ForgeError::NotFound` when a buffered workflow is unknown to the workflow
///   dispatcher, or no workflow dispatcher is attached.
/// * Any error returned by the handler or by the insert helpers.
///
/// # Panics
///
/// Panics if the outbox mutex is poisoned.
async fn execute_transactional(
    &self,
    handler: &BoxedMutationFn,
    args: Value,
    auth: AuthContext,
    request: RequestMetadata,
) -> Result<RouteResult> {
    // Everything below runs inside this span via `.instrument(span)`.
    let span = tracing::info_span!("db.transaction", db.system = "postgresql",);

    async {
        let primary = self.db.primary();
        let tx = primary
            .begin()
            .await
            .map_err(|e| ForgeError::Database(e.to_string()))?;

        // The lookup closure owns its own clone of the dispatcher; it yields `None`
        // for every name when no job dispatcher is attached.
        let job_dispatcher = self.job_dispatcher.clone();
        let job_lookup: forge_core::JobInfoLookup =
            Arc::new(move |name: &str| job_dispatcher.as_ref().and_then(|d| d.get_info(name)));

        let (ctx, tx_handle, outbox) = MutationContext::with_transaction(
            primary.clone(),
            tx,
            auth,
            request,
            self.http_client.clone(),
            job_lookup,
        );

        match handler(&ctx, args).await {
            Ok(value) => {
                // Copy the buffered work out while holding the lock; the guard is
                // dropped at the end of this block, before any later `.await`.
                let buffer = {
                    let guard = outbox.lock().expect("outbox mutex poisoned");
                    OutboxBuffer {
                        jobs: guard.jobs.clone(),
                        workflows: guard.workflows.clone(),
                    }
                };

                // `Arc::try_unwrap` succeeds only when this is the last strong
                // reference to the transaction.
                // NOTE(review): `ctx` is still in scope here; confirm it does not keep
                // a strong reference to the same Arc, otherwise this always fails.
                let mut tx = Arc::try_unwrap(tx_handle)
                    .map_err(|_| ForgeError::Internal("Transaction still in use".into()))?
                    .into_inner();

                for job in &buffer.jobs {
                    Self::insert_job(&mut tx, job).await?;
                }

                for workflow in &buffer.workflows {
                    // The stored run records the workflow version reported by the
                    // dispatcher at insert time.
                    let version = self
                        .workflow_dispatcher
                        .as_ref()
                        .and_then(|d| d.get_info(&workflow.workflow_name))
                        .map(|info| info.version)
                        .ok_or_else(|| {
                            ForgeError::NotFound(format!(
                                "Workflow '{}' not found",
                                workflow.workflow_name
                            ))
                        })?;
                    Self::insert_workflow(&mut tx, workflow, version).await?;
                }

                // Any `?` above returns before this point, so nothing is committed
                // unless every insert succeeded.
                tx.commit()
                    .await
                    .map_err(|e| ForgeError::Database(e.to_string()))?;

                Ok(RouteResult::Mutation(value))
            }
            // No commit on this path.
            // NOTE(review): presumably the transaction is rolled back when its handle
            // is dropped (sqlx rolls back uncommitted transactions on drop) — confirm.
            Err(e) => Err(e),
        }
    }
    .instrument(span)
    .await
}
529
530 async fn insert_job(
531 tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
532 job: &PendingJob,
533 ) -> Result<()> {
534 let now = Utc::now();
535 sqlx::query(
536 r#"
537 INSERT INTO forge_jobs (
538 id, job_type, input, job_context, status, priority, attempts, max_attempts,
539 worker_capability, owner_subject, scheduled_at, created_at
540 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
541 "#,
542 )
543 .bind(job.id)
544 .bind(&job.job_type)
545 .bind(&job.args)
546 .bind(&job.context)
547 .bind(JobStatus::Pending.as_str())
548 .bind(job.priority)
549 .bind(0i32)
550 .bind(job.max_attempts)
551 .bind(&job.worker_capability)
552 .bind(&job.owner_subject)
553 .bind(now)
554 .bind(now)
555 .execute(&mut **tx)
556 .await
557 .map_err(|e| ForgeError::Database(e.to_string()))?;
558
559 Ok(())
560 }
561
562 async fn insert_workflow(
563 tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
564 workflow: &PendingWorkflow,
565 version: u32,
566 ) -> Result<()> {
567 let now = Utc::now();
568 sqlx::query(
569 r#"
570 INSERT INTO forge_workflow_runs (
571 id, workflow_name, version, owner_subject, input, status, current_step,
572 step_results, started_at, trace_id
573 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
574 "#,
575 )
576 .bind(workflow.id)
577 .bind(&workflow.workflow_name)
578 .bind(version as i32)
579 .bind(&workflow.owner_subject)
580 .bind(&workflow.input)
581 .bind(WorkflowStatus::Created.as_str())
582 .bind(Option::<String>::None)
583 .bind(serde_json::json!({}))
584 .bind(now)
585 .bind(workflow.id.to_string())
586 .execute(&mut **tx)
587 .await
588 .map_err(|e| ForgeError::Database(e.to_string()))?;
589
590 Ok(())
591 }
592}
593
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a claims map from string name/value pairs.
    fn string_claims(pairs: &[(&str, &str)]) -> HashMap<String, serde_json::Value> {
        pairs
            .iter()
            .map(|(name, value)| {
                (
                    name.to_string(),
                    serde_json::Value::String(value.to_string()),
                )
            })
            .collect()
    }

    /// Authenticated caller with role `user`, a `sub` claim equal to `user_id`,
    /// and any additional string claims.
    fn uuid_user(user_id: uuid::Uuid, extra_claims: &[(&str, &str)]) -> AuthContext {
        let subject = user_id.to_string();
        let mut claims = string_claims(extra_claims);
        claims.insert("sub".to_string(), serde_json::Value::String(subject));
        AuthContext::authenticated(user_id, vec!["user".to_string()], claims)
    }

    /// Shorthand for the identity check with scope enforcement switched on.
    fn check_scoped(args: &serde_json::Value, auth: &AuthContext) -> Result<()> {
        FunctionRouter::check_identity_args("list_orders", args, auth, true)
    }

    #[test]
    fn test_check_auth_public() {
        let info = FunctionInfo {
            name: "test",
            description: None,
            kind: FunctionKind::Query,
            required_role: None,
            is_public: true,
            cache_ttl: None,
            timeout: None,
            rate_limit_requests: None,
            rate_limit_per_secs: None,
            rate_limit_key: None,
            log_level: None,
            table_dependencies: &[],
            selected_columns: &[],
            transactional: false,
            consistent: false,
        };

        let _auth = AuthContext::unauthenticated();

        // `check_auth` takes `&self`, and no router is constructed in this test,
        // so only the descriptor flag itself is asserted.
        assert!(info.is_public);
    }

    #[test]
    fn test_identity_args_reject_cross_user_value() {
        let caller = uuid_user(uuid::Uuid::new_v4(), &[]);
        let someone_else = uuid::Uuid::new_v4().to_string();
        let args = serde_json::json!({ "user_id": someone_else });

        let outcome = check_scoped(&args, &caller);
        assert!(matches!(outcome, Err(ForgeError::Forbidden(_))));
    }

    #[test]
    fn test_identity_args_allow_matching_subject() {
        let sub = "user_123";
        let caller = AuthContext::authenticated_without_uuid(
            vec!["user".to_string()],
            string_claims(&[("sub", sub)]),
        );
        let args = serde_json::json!({ "subject": sub });

        let outcome = check_scoped(&args, &caller);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_identity_args_require_auth_for_identity_keys() {
        let caller = AuthContext::unauthenticated();
        let args = serde_json::json!({ "user_id": uuid::Uuid::new_v4().to_string() });

        let outcome = check_scoped(&args, &caller);
        assert!(matches!(outcome, Err(ForgeError::Unauthorized(_))));
    }

    #[test]
    fn test_identity_args_require_scope_for_non_public_calls() {
        let caller = uuid_user(uuid::Uuid::new_v4(), &[]);

        let outcome = check_scoped(&serde_json::json!({}), &caller);
        assert!(matches!(outcome, Err(ForgeError::Forbidden(_))));
    }

    #[test]
    fn test_auth_cache_scope_changes_with_claims() {
        let user_id = uuid::Uuid::new_v4();
        let auth_a = uuid_user(user_id, &[("tenant_id", "tenant-a")]);
        let auth_b = uuid_user(user_id, &[("tenant_id", "tenant-b")]);

        let scope_a = FunctionRouter::auth_cache_scope(&auth_a);
        let scope_b = FunctionRouter::auth_cache_scope(&auth_b);
        assert_ne!(scope_a, scope_b);
    }
}
727}