1use std::collections::BTreeMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use chrono::Utc;
6use forge_core::{
7 AuthContext, CircuitBreakerClient, ForgeError, FunctionInfo, FunctionKind, JobDispatch,
8 MutationContext, OutboxBuffer, PendingJob, PendingWorkflow, QueryContext, RequestMetadata,
9 Result, WorkflowDispatch,
10 job::JobStatus,
11 rate_limit::{RateLimitConfig, RateLimitKey},
12 workflow::WorkflowStatus,
13};
14use serde_json::Value;
15use tracing::Instrument;
16
17use super::cache::QueryCache;
18use super::registry::{BoxedMutationFn, FunctionEntry, FunctionRegistry};
19use crate::db::Database;
20use crate::rate_limit::HybridRateLimiter;
21
22pub enum RouteResult {
24 Query(Value),
26 Mutation(Value),
28 Job(Value),
30 Workflow(Value),
32}
33
34pub struct FunctionRouter {
36 registry: Arc<FunctionRegistry>,
37 db: Database,
38 http_client: CircuitBreakerClient,
39 job_dispatcher: Option<Arc<dyn JobDispatch>>,
40 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
41 rate_limiter: HybridRateLimiter,
42 query_cache: QueryCache,
43 token_issuer: Option<Arc<dyn forge_core::TokenIssuer>>,
44}
45
46impl FunctionRouter {
47 pub fn new(registry: Arc<FunctionRegistry>, db: Database) -> Self {
49 let rate_limiter = HybridRateLimiter::new(db.primary().clone());
50 Self {
51 registry,
52 db,
53 http_client: CircuitBreakerClient::with_defaults(reqwest::Client::new()),
54 job_dispatcher: None,
55 workflow_dispatcher: None,
56 rate_limiter,
57 query_cache: QueryCache::new(),
58 token_issuer: None,
59 }
60 }
61
62 pub fn with_http_client(
64 registry: Arc<FunctionRegistry>,
65 db: Database,
66 http_client: CircuitBreakerClient,
67 ) -> Self {
68 let rate_limiter = HybridRateLimiter::new(db.primary().clone());
69 Self {
70 registry,
71 db,
72 http_client,
73 job_dispatcher: None,
74 workflow_dispatcher: None,
75 rate_limiter,
76 query_cache: QueryCache::new(),
77 token_issuer: None,
78 }
79 }
80
81 pub fn with_token_issuer(mut self, issuer: Arc<dyn forge_core::TokenIssuer>) -> Self {
83 self.token_issuer = Some(issuer);
84 self
85 }
86
87 pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
89 self.job_dispatcher = Some(dispatcher);
90 self
91 }
92
93 pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
95 self.workflow_dispatcher = Some(dispatcher);
96 self
97 }
98
99 pub async fn route(
100 &self,
101 function_name: &str,
102 args: Value,
103 auth: AuthContext,
104 request: RequestMetadata,
105 ) -> Result<RouteResult> {
106 if let Some(entry) = self.registry.get(function_name) {
107 self.check_auth(entry.info(), &auth)?;
108 self.check_rate_limit(entry.info(), function_name, &auth, &request)
109 .await?;
110 let enforce = !entry.info().is_public && entry.info().has_input_args;
113 auth.check_identity_args(function_name, &args, enforce)?;
114
115 return match entry {
116 FunctionEntry::Query { handler, info, .. } => {
117 let pool = if info.consistent {
118 self.db.primary().clone()
119 } else {
120 self.db.read_pool().clone()
121 };
122
123 let auth_scope = Self::auth_cache_scope(&auth);
124 if let Some(ttl) = info.cache_ttl {
125 if let Some(cached) =
126 self.query_cache
127 .get(function_name, &args, auth_scope.as_deref())
128 {
129 return Ok(RouteResult::Query(Value::clone(&cached)));
130 }
131
132 let ctx = QueryContext::new(pool, auth, request);
133 let result = handler(&ctx, args.clone()).await?;
134
135 self.query_cache.set(
136 function_name,
137 &args,
138 auth_scope.as_deref(),
139 result.clone(),
140 Duration::from_secs(ttl),
141 );
142
143 Ok(RouteResult::Query(result))
144 } else {
145 let ctx = QueryContext::new(pool, auth, request);
146 let result = handler(&ctx, args).await?;
147 Ok(RouteResult::Query(result))
148 }
149 }
150 FunctionEntry::Mutation { handler, info } => {
151 if info.transactional {
152 self.execute_transactional(handler, args, auth, request)
153 .await
154 } else {
155 let mut ctx = MutationContext::with_dispatch(
157 self.db.primary().clone(),
158 auth,
159 request,
160 self.http_client.clone(),
161 self.job_dispatcher.clone(),
162 self.workflow_dispatcher.clone(),
163 );
164 if let Some(ref issuer) = self.token_issuer {
165 ctx.set_token_issuer(issuer.clone());
166 }
167 let result = handler(&ctx, args).await?;
168 Ok(RouteResult::Mutation(result))
169 }
170 }
171 };
172 }
173
174 if let Some(ref job_dispatcher) = self.job_dispatcher
175 && let Some(job_info) = job_dispatcher.get_info(function_name)
176 {
177 self.check_job_auth(&job_info, &auth)?;
178 auth.check_identity_args(function_name, &args, !job_info.is_public)?;
179 match job_dispatcher
180 .dispatch_by_name(function_name, args.clone(), auth.principal_id())
181 .await
182 {
183 Ok(job_id) => {
184 return Ok(RouteResult::Job(serde_json::json!({ "job_id": job_id })));
185 }
186 Err(ForgeError::NotFound(_)) => {}
187 Err(e) => return Err(e),
188 }
189 }
190
191 if let Some(ref workflow_dispatcher) = self.workflow_dispatcher
192 && let Some(workflow_info) = workflow_dispatcher.get_info(function_name)
193 {
194 self.check_workflow_auth(&workflow_info, &auth)?;
195 auth.check_identity_args(function_name, &args, !workflow_info.is_public)?;
196 match workflow_dispatcher
197 .start_by_name(function_name, args.clone(), auth.principal_id())
198 .await
199 {
200 Ok(workflow_id) => {
201 return Ok(RouteResult::Workflow(
202 serde_json::json!({ "workflow_id": workflow_id }),
203 ));
204 }
205 Err(ForgeError::NotFound(_)) => {}
206 Err(e) => return Err(e),
207 }
208 }
209
210 Err(ForgeError::NotFound(format!(
211 "Function '{}' not found",
212 function_name
213 )))
214 }
215
216 fn check_auth(&self, info: &FunctionInfo, auth: &AuthContext) -> Result<()> {
217 if info.is_public {
218 return Ok(());
219 }
220
221 if !auth.is_authenticated() {
222 return Err(ForgeError::Unauthorized("Authentication required".into()));
223 }
224
225 if let Some(role) = info.required_role
226 && !auth.has_role(role)
227 {
228 return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
229 }
230
231 Ok(())
232 }
233
234 fn check_job_auth(&self, info: &forge_core::job::JobInfo, auth: &AuthContext) -> Result<()> {
235 if info.is_public {
236 return Ok(());
237 }
238
239 if !auth.is_authenticated() {
240 return Err(ForgeError::Unauthorized("Authentication required".into()));
241 }
242
243 if let Some(role) = info.required_role
244 && !auth.has_role(role)
245 {
246 return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
247 }
248
249 Ok(())
250 }
251
252 fn check_workflow_auth(
253 &self,
254 info: &forge_core::workflow::WorkflowInfo,
255 auth: &AuthContext,
256 ) -> Result<()> {
257 if info.is_public {
258 return Ok(());
259 }
260
261 if !auth.is_authenticated() {
262 return Err(ForgeError::Unauthorized("Authentication required".into()));
263 }
264
265 if let Some(role) = info.required_role
266 && !auth.has_role(role)
267 {
268 return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
269 }
270
271 Ok(())
272 }
273
274 async fn check_rate_limit(
276 &self,
277 info: &FunctionInfo,
278 function_name: &str,
279 auth: &AuthContext,
280 request: &RequestMetadata,
281 ) -> Result<()> {
282 let (requests, per_secs) = match (info.rate_limit_requests, info.rate_limit_per_secs) {
284 (Some(r), Some(p)) => (r, p),
285 _ => return Ok(()),
286 };
287
288 let key_str = info.rate_limit_key.unwrap_or("user");
290 let key_type: RateLimitKey = match key_str.parse() {
291 Ok(k) => k,
292 Err(_) => {
293 tracing::warn!(
294 function = %function_name,
295 key = %key_str,
296 "Invalid rate limit key, falling back to 'user'"
297 );
298 RateLimitKey::default()
299 }
300 };
301
302 let config =
303 RateLimitConfig::new(requests, Duration::from_secs(per_secs)).with_key(key_type);
304
305 let bucket_key = self
307 .rate_limiter
308 .build_key(key_type, function_name, auth, request);
309
310 self.rate_limiter.enforce(&bucket_key, &config).await?;
312
313 Ok(())
314 }
315
316 fn auth_cache_scope(auth: &AuthContext) -> Option<String> {
317 if !auth.is_authenticated() {
318 return Some("anon".to_string());
319 }
320
321 let mut roles = auth.roles().to_vec();
323 roles.sort();
324 roles.dedup();
325
326 let mut claims = BTreeMap::new();
327 for (k, v) in auth.claims() {
328 claims.insert(k.clone(), v.clone());
329 }
330
331 use std::collections::hash_map::DefaultHasher;
332 use std::hash::{Hash, Hasher};
333
334 let mut hasher = DefaultHasher::new();
335 roles.hash(&mut hasher);
336 serde_json::to_string(&claims)
337 .unwrap_or_default()
338 .hash(&mut hasher);
339
340 let principal = auth
341 .principal_id()
342 .unwrap_or_else(|| "authenticated".to_string());
343
344 Some(format!(
345 "subject:{principal}:scope:{:016x}",
346 hasher.finish()
347 ))
348 }
349
350 pub fn get_function_kind(&self, function_name: &str) -> Option<FunctionKind> {
352 self.registry.get(function_name).map(|e| e.kind())
353 }
354
355 pub fn has_function(&self, function_name: &str) -> bool {
357 self.registry.get(function_name).is_some()
358 }
359
360 async fn execute_transactional(
361 &self,
362 handler: &BoxedMutationFn,
363 args: Value,
364 auth: AuthContext,
365 request: RequestMetadata,
366 ) -> Result<RouteResult> {
367 let span = tracing::info_span!("db.transaction", db.system = "postgresql",);
368
369 async {
370 let primary = self.db.primary();
371 let tx = primary
372 .begin()
373 .await
374 .map_err(|e| ForgeError::Database(e.to_string()))?;
375
376 let job_dispatcher = self.job_dispatcher.clone();
377 let job_lookup: forge_core::JobInfoLookup =
378 Arc::new(move |name: &str| job_dispatcher.as_ref().and_then(|d| d.get_info(name)));
379
380 let (mut ctx, tx_handle, outbox) = MutationContext::with_transaction(
381 primary.clone(),
382 tx,
383 auth,
384 request,
385 self.http_client.clone(),
386 job_lookup,
387 );
388 if let Some(ref issuer) = self.token_issuer {
389 ctx.set_token_issuer(issuer.clone());
390 }
391
392 match handler(&ctx, args).await {
393 Ok(value) => {
394 let buffer = {
395 let guard = outbox.lock().expect("outbox mutex poisoned");
396 OutboxBuffer {
397 jobs: guard.jobs.clone(),
398 workflows: guard.workflows.clone(),
399 }
400 };
401
402 let mut tx = Arc::try_unwrap(tx_handle)
403 .map_err(|_| ForgeError::Internal("Transaction still in use".into()))?
404 .into_inner();
405
406 for job in &buffer.jobs {
407 Self::insert_job(&mut tx, job).await?;
408 }
409
410 for workflow in &buffer.workflows {
411 let version = self
412 .workflow_dispatcher
413 .as_ref()
414 .and_then(|d| d.get_info(&workflow.workflow_name))
415 .map(|info| info.version)
416 .ok_or_else(|| {
417 ForgeError::NotFound(format!(
418 "Workflow '{}' not found",
419 workflow.workflow_name
420 ))
421 })?;
422 Self::insert_workflow(&mut tx, workflow, version).await?;
423 }
424
425 tx.commit()
426 .await
427 .map_err(|e| ForgeError::Database(e.to_string()))?;
428
429 Ok(RouteResult::Mutation(value))
430 }
431 Err(e) => Err(e),
432 }
433 }
434 .instrument(span)
435 .await
436 }
437
438 async fn insert_job(
439 tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
440 job: &PendingJob,
441 ) -> Result<()> {
442 let now = Utc::now();
443 sqlx::query(
444 r#"
445 INSERT INTO forge_jobs (
446 id, job_type, input, job_context, status, priority, attempts, max_attempts,
447 worker_capability, owner_subject, scheduled_at, created_at
448 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
449 "#,
450 )
451 .bind(job.id)
452 .bind(&job.job_type)
453 .bind(&job.args)
454 .bind(&job.context)
455 .bind(JobStatus::Pending.as_str())
456 .bind(job.priority)
457 .bind(0i32)
458 .bind(job.max_attempts)
459 .bind(&job.worker_capability)
460 .bind(&job.owner_subject)
461 .bind(now)
462 .bind(now)
463 .execute(&mut **tx)
464 .await
465 .map_err(|e| ForgeError::Database(e.to_string()))?;
466
467 Ok(())
468 }
469
470 async fn insert_workflow(
471 tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
472 workflow: &PendingWorkflow,
473 version: u32,
474 ) -> Result<()> {
475 let now = Utc::now();
476 sqlx::query(
477 r#"
478 INSERT INTO forge_workflow_runs (
479 id, workflow_name, version, owner_subject, input, status, current_step,
480 step_results, started_at, trace_id
481 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
482 "#,
483 )
484 .bind(workflow.id)
485 .bind(&workflow.workflow_name)
486 .bind(version as i32)
487 .bind(&workflow.owner_subject)
488 .bind(&workflow.input)
489 .bind(WorkflowStatus::Created.as_str())
490 .bind(Option::<String>::None)
491 .bind(serde_json::json!({}))
492 .bind(now)
493 .bind(workflow.id.to_string())
494 .execute(&mut **tx)
495 .await
496 .map_err(|e| ForgeError::Database(e.to_string()))?;
497
498 Ok(())
499 }
500}
501
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Authenticated context for `user_id` with the `user` role, a `sub`
    /// claim equal to the UUID, and any `extra` string-valued claims.
    fn user_auth(user_id: uuid::Uuid, extra: &[(&str, &str)]) -> AuthContext {
        let mut claims = HashMap::from([(
            "sub".to_string(),
            serde_json::Value::String(user_id.to_string()),
        )]);
        for &(key, value) in extra {
            claims.insert(key.to_string(), serde_json::Value::String(value.to_string()));
        }
        AuthContext::authenticated(user_id, vec!["user".to_string()], claims)
    }

    #[test]
    fn test_check_auth_public() {
        let info = FunctionInfo {
            name: "test",
            description: None,
            kind: FunctionKind::Query,
            required_role: None,
            is_public: true,
            cache_ttl: None,
            timeout: None,
            rate_limit_requests: None,
            rate_limit_per_secs: None,
            rate_limit_key: None,
            log_level: None,
            table_dependencies: &[],
            selected_columns: &[],
            transactional: false,
            consistent: false,
            has_input_args: false,
        };

        let _anon = AuthContext::unauthenticated();

        assert!(info.is_public);
    }

    #[test]
    fn test_identity_args_reject_cross_user_value() {
        let auth = user_auth(uuid::Uuid::new_v4(), &[]);
        let someone_else = serde_json::json!({
            "user_id": uuid::Uuid::new_v4().to_string()
        });

        assert!(matches!(
            auth.check_identity_args("list_orders", &someone_else, true),
            Err(ForgeError::Forbidden(_))
        ));
    }

    #[test]
    fn test_identity_args_allow_matching_subject() {
        let subject = "user_123";
        let auth = AuthContext::authenticated_without_uuid(
            vec!["user".to_string()],
            HashMap::from([(
                "sub".to_string(),
                serde_json::Value::String(subject.to_string()),
            )]),
        );
        let own = serde_json::json!({ "subject": subject });

        assert!(auth.check_identity_args("list_orders", &own, true).is_ok());
    }

    #[test]
    fn test_identity_args_require_auth_for_identity_keys() {
        let anon = AuthContext::unauthenticated();
        let args = serde_json::json!({
            "user_id": uuid::Uuid::new_v4().to_string()
        });

        assert!(matches!(
            anon.check_identity_args("list_orders", &args, true),
            Err(ForgeError::Unauthorized(_))
        ));
    }

    #[test]
    fn test_identity_args_require_scope_for_non_public_calls() {
        let auth = user_auth(uuid::Uuid::new_v4(), &[]);

        assert!(matches!(
            auth.check_identity_args("list_orders", &serde_json::json!({}), true),
            Err(ForgeError::Forbidden(_))
        ));
    }

    #[test]
    fn test_identity_args_skip_scope_for_no_input_functions() {
        let auth = user_auth(uuid::Uuid::new_v4(), &[]);

        assert!(auth
            .check_identity_args("list_todos", &serde_json::Value::Null, false)
            .is_ok());
    }

    #[test]
    fn test_auth_cache_scope_changes_with_claims() {
        let user_id = uuid::Uuid::new_v4();
        let tenant_a = user_auth(user_id, &[("tenant_id", "tenant-a")]);
        let tenant_b = user_auth(user_id, &[("tenant_id", "tenant-b")]);

        assert_ne!(
            FunctionRouter::auth_cache_scope(&tenant_a),
            FunctionRouter::auth_cache_scope(&tenant_b)
        );
    }
}
652}