1use std::collections::BTreeMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use chrono::Utc;
6use forge_core::{
7 AuthContext, CircuitBreakerClient, ForgeError, FunctionInfo, FunctionKind, JobDispatch,
8 MutationContext, OutboxBuffer, PendingJob, PendingWorkflow, QueryContext, RequestMetadata,
9 Result, WorkflowDispatch,
10 job::JobStatus,
11 rate_limit::{RateLimitConfig, RateLimitKey},
12 workflow::WorkflowStatus,
13};
14use serde_json::Value;
15use tracing::Instrument;
16
17use super::cache::QueryCache;
18use super::registry::{BoxedMutationFn, FunctionEntry, FunctionRegistry};
19use crate::db::Database;
20use crate::rate_limit::HybridRateLimiter;
21
22fn require_auth(is_public: bool, required_role: Option<&str>, auth: &AuthContext) -> Result<()> {
24 if is_public {
25 return Ok(());
26 }
27 if !auth.is_authenticated() {
28 return Err(ForgeError::Unauthorized("Authentication required".into()));
29 }
30 if let Some(role) = required_role
31 && !auth.has_role(role)
32 {
33 return Err(ForgeError::Forbidden(format!("Role '{role}' required")));
34 }
35 Ok(())
36}
37
38pub enum RouteResult {
40 Query(Value),
42 Mutation(Value),
44 Job(Value),
46 Workflow(Value),
48}
49
/// Routes calls by name to registered queries/mutations, falling back to the
/// job and workflow dispatchers when they are configured.
pub struct FunctionRouter {
    /// Registered query and mutation handlers, looked up by function name.
    registry: Arc<FunctionRegistry>,
    /// Database handle: the primary pool serves mutations, consistent queries
    /// and the rate limiter; the read pool serves non-consistent queries.
    db: Database,
    /// HTTP client handed to every mutation context.
    http_client: CircuitBreakerClient,
    /// Optional job dispatcher; also exposed to transactional mutations via a
    /// job-info lookup.
    job_dispatcher: Option<Arc<dyn JobDispatch>>,
    /// Optional workflow dispatcher; also used to check that workflows started
    /// from transactional mutations exist.
    workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
    /// Per-function rate limiter, built on the primary pool.
    rate_limiter: HybridRateLimiter,
    /// Result cache for queries declaring a `cache_ttl`, partitioned by caller scope.
    query_cache: QueryCache,
    /// Optional token issuer forwarded to mutation contexts.
    token_issuer: Option<Arc<dyn forge_core::TokenIssuer>>,
    /// Token lifetime settings forwarded to mutation contexts.
    token_ttl: forge_core::AuthTokenTtl,
}
62
63impl FunctionRouter {
64 pub fn new(registry: Arc<FunctionRegistry>, db: Database) -> Self {
66 let rate_limiter = HybridRateLimiter::new(db.primary().clone());
67 Self {
68 registry,
69 db,
70 http_client: CircuitBreakerClient::with_defaults(reqwest::Client::new()),
71 job_dispatcher: None,
72 workflow_dispatcher: None,
73 rate_limiter,
74 query_cache: QueryCache::new(),
75 token_issuer: None,
76 token_ttl: forge_core::AuthTokenTtl::default(),
77 }
78 }
79
80 pub fn with_http_client(
82 registry: Arc<FunctionRegistry>,
83 db: Database,
84 http_client: CircuitBreakerClient,
85 ) -> Self {
86 let rate_limiter = HybridRateLimiter::new(db.primary().clone());
87 Self {
88 registry,
89 db,
90 http_client,
91 job_dispatcher: None,
92 workflow_dispatcher: None,
93 rate_limiter,
94 query_cache: QueryCache::new(),
95 token_issuer: None,
96 token_ttl: forge_core::AuthTokenTtl::default(),
97 }
98 }
99
100 pub fn with_token_issuer(mut self, issuer: Arc<dyn forge_core::TokenIssuer>) -> Self {
102 self.token_issuer = Some(issuer);
103 self
104 }
105
106 pub fn with_token_ttl(mut self, ttl: forge_core::AuthTokenTtl) -> Self {
108 self.token_ttl = ttl;
109 self
110 }
111
112 pub fn set_token_ttl(&mut self, ttl: forge_core::AuthTokenTtl) {
114 self.token_ttl = ttl;
115 }
116
117 pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
119 self.job_dispatcher = Some(dispatcher);
120 self
121 }
122
123 pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
125 self.workflow_dispatcher = Some(dispatcher);
126 self
127 }
128
129 pub async fn route(
130 &self,
131 function_name: &str,
132 args: Value,
133 auth: AuthContext,
134 request: RequestMetadata,
135 ) -> Result<RouteResult> {
136 if let Some(entry) = self.registry.get(function_name) {
137 self.check_auth(entry.info(), &auth)?;
138 self.check_rate_limit(entry.info(), function_name, &auth, &request)
139 .await?;
140 let enforce = !entry.info().is_public && entry.info().has_input_args;
143 auth.check_identity_args(function_name, &args, enforce)?;
144
145 return match entry {
146 FunctionEntry::Query { handler, info, .. } => {
147 let pool = if info.consistent {
148 self.db.primary().clone()
149 } else {
150 self.db.read_pool().clone()
151 };
152
153 let auth_scope = Self::auth_cache_scope(&auth);
154 if let Some(ttl) = info.cache_ttl {
155 if let Some(cached) =
156 self.query_cache
157 .get(function_name, &args, auth_scope.as_deref())
158 {
159 return Ok(RouteResult::Query(Value::clone(&cached)));
160 }
161
162 let ctx = QueryContext::new(pool, auth, request);
163 let result = handler(&ctx, args.clone()).await?;
164
165 self.query_cache.set(
166 function_name,
167 &args,
168 auth_scope.as_deref(),
169 result.clone(),
170 Duration::from_secs(ttl),
171 );
172
173 Ok(RouteResult::Query(result))
174 } else {
175 let ctx = QueryContext::new(pool, auth, request);
176 let result = handler(&ctx, args).await?;
177 Ok(RouteResult::Query(result))
178 }
179 }
180 FunctionEntry::Mutation { handler, info } => {
181 if info.transactional {
182 self.execute_transactional(info, handler, args, auth, request)
183 .await
184 } else {
185 let mut ctx = MutationContext::with_dispatch(
187 self.db.primary().clone(),
188 auth,
189 request,
190 self.http_client.clone(),
191 self.job_dispatcher.clone(),
192 self.workflow_dispatcher.clone(),
193 );
194 if let Some(ref issuer) = self.token_issuer {
195 ctx.set_token_issuer(issuer.clone());
196 }
197 ctx.set_token_ttl(self.token_ttl.clone());
198 ctx.set_http_timeout(info.http_timeout.map(Duration::from_secs));
199 let result = handler(&ctx, args).await?;
200 Ok(RouteResult::Mutation(result))
201 }
202 }
203 };
204 }
205
206 if let Some(ref job_dispatcher) = self.job_dispatcher
207 && let Some(job_info) = job_dispatcher.get_info(function_name)
208 {
209 self.check_job_auth(&job_info, &auth)?;
210 auth.check_identity_args(function_name, &args, !job_info.is_public)?;
211 match job_dispatcher
212 .dispatch_by_name(function_name, args.clone(), auth.principal_id())
213 .await
214 {
215 Ok(job_id) => {
216 return Ok(RouteResult::Job(serde_json::json!({ "job_id": job_id })));
217 }
218 Err(ForgeError::NotFound(_)) => {}
219 Err(e) => return Err(e),
220 }
221 }
222
223 if let Some(ref workflow_dispatcher) = self.workflow_dispatcher
224 && let Some(workflow_info) = workflow_dispatcher.get_info(function_name)
225 {
226 self.check_workflow_auth(&workflow_info, &auth)?;
227 auth.check_identity_args(function_name, &args, !workflow_info.is_public)?;
228 match workflow_dispatcher
229 .start_by_name(function_name, args.clone(), auth.principal_id())
230 .await
231 {
232 Ok(workflow_id) => {
233 return Ok(RouteResult::Workflow(
234 serde_json::json!({ "workflow_id": workflow_id }),
235 ));
236 }
237 Err(ForgeError::NotFound(_)) => {}
238 Err(e) => return Err(e),
239 }
240 }
241
242 Err(ForgeError::NotFound(format!(
243 "Function '{}' not found",
244 function_name
245 )))
246 }
247
248 fn check_auth(&self, info: &FunctionInfo, auth: &AuthContext) -> Result<()> {
249 require_auth(info.is_public, info.required_role, auth)
250 }
251
252 fn check_job_auth(&self, info: &forge_core::job::JobInfo, auth: &AuthContext) -> Result<()> {
253 require_auth(info.is_public, info.required_role, auth)
254 }
255
256 fn check_workflow_auth(
257 &self,
258 info: &forge_core::workflow::WorkflowInfo,
259 auth: &AuthContext,
260 ) -> Result<()> {
261 require_auth(info.is_public, info.required_role, auth)
262 }
263
264 async fn check_rate_limit(
266 &self,
267 info: &FunctionInfo,
268 function_name: &str,
269 auth: &AuthContext,
270 request: &RequestMetadata,
271 ) -> Result<()> {
272 let (requests, per_secs) = match (info.rate_limit_requests, info.rate_limit_per_secs) {
274 (Some(r), Some(p)) => (r, p),
275 _ => return Ok(()),
276 };
277
278 let key_str = info.rate_limit_key.unwrap_or("user");
280 let key_type: RateLimitKey = match key_str.parse() {
281 Ok(k) => k,
282 Err(_) => {
283 tracing::error!(
284 function = %function_name,
285 key = %key_str,
286 "Invalid rate limit key, falling back to 'user'"
287 );
288 RateLimitKey::default()
289 }
290 };
291
292 let config =
293 RateLimitConfig::new(requests, Duration::from_secs(per_secs)).with_key(key_type);
294
295 let bucket_key = self
297 .rate_limiter
298 .build_key(key_type, function_name, auth, request);
299
300 self.rate_limiter.enforce(&bucket_key, &config).await?;
302
303 Ok(())
304 }
305
306 fn auth_cache_scope(auth: &AuthContext) -> Option<String> {
307 if !auth.is_authenticated() {
308 return Some("anon".to_string());
309 }
310
311 let mut roles = auth.roles().to_vec();
313 roles.sort();
314 roles.dedup();
315
316 let mut claims = BTreeMap::new();
317 for (k, v) in auth.claims() {
318 claims.insert(k.clone(), v.clone());
319 }
320
321 use std::collections::hash_map::DefaultHasher;
322 use std::hash::{Hash, Hasher};
323
324 let mut hasher = DefaultHasher::new();
325 roles.hash(&mut hasher);
326 serde_json::to_string(&claims)
327 .unwrap_or_default()
328 .hash(&mut hasher);
329
330 let principal = auth
331 .principal_id()
332 .unwrap_or_else(|| "authenticated".to_string());
333
334 Some(format!(
335 "subject:{principal}:scope:{:016x}",
336 hasher.finish()
337 ))
338 }
339
340 pub fn get_function_kind(&self, function_name: &str) -> Option<FunctionKind> {
342 self.registry.get(function_name).map(|e| e.kind())
343 }
344
345 pub fn has_function(&self, function_name: &str) -> bool {
347 self.registry.get(function_name).is_some()
348 }
349
350 async fn execute_transactional(
351 &self,
352 info: &FunctionInfo,
353 handler: &BoxedMutationFn,
354 args: Value,
355 auth: AuthContext,
356 request: RequestMetadata,
357 ) -> Result<RouteResult> {
358 let span = tracing::info_span!("db.transaction", db.system = "postgresql",);
359
360 async {
361 let primary = self.db.primary();
362 let tx = primary
363 .begin()
364 .await
365 .map_err(|e| ForgeError::Database(e.to_string()))?;
366
367 let job_dispatcher = self.job_dispatcher.clone();
368 let job_lookup: forge_core::JobInfoLookup =
369 Arc::new(move |name: &str| job_dispatcher.as_ref().and_then(|d| d.get_info(name)));
370
371 let (mut ctx, tx_handle, outbox) = MutationContext::with_transaction(
372 primary.clone(),
373 tx,
374 auth,
375 request,
376 self.http_client.clone(),
377 job_lookup,
378 );
379 if let Some(ref issuer) = self.token_issuer {
380 ctx.set_token_issuer(issuer.clone());
381 }
382 ctx.set_token_ttl(self.token_ttl.clone());
383 ctx.set_http_timeout(info.http_timeout.map(Duration::from_secs));
384
385 match handler(&ctx, args).await {
386 Ok(value) => {
387 let buffer = {
388 let guard = outbox.lock().unwrap_or_else(|poisoned| {
389 tracing::error!("Outbox mutex was poisoned, recovering");
390 poisoned.into_inner()
391 });
392 OutboxBuffer {
393 jobs: guard.jobs.clone(),
394 workflows: guard.workflows.clone(),
395 }
396 };
397
398 let mut tx = Arc::try_unwrap(tx_handle)
399 .map_err(|_| ForgeError::Internal("Transaction still in use".into()))?
400 .into_inner();
401
402 for job in &buffer.jobs {
403 Self::insert_job(&mut tx, job).await?;
404 }
405
406 for workflow in &buffer.workflows {
407 if self
408 .workflow_dispatcher
409 .as_ref()
410 .and_then(|d| d.get_info(&workflow.workflow_name))
411 .is_none()
412 {
413 return Err(ForgeError::NotFound(format!(
414 "Workflow '{}' not found",
415 workflow.workflow_name
416 )));
417 }
418 Self::insert_workflow(&mut tx, workflow).await?;
419 }
420
421 tx.commit()
422 .await
423 .map_err(|e| ForgeError::Database(e.to_string()))?;
424
425 Ok(RouteResult::Mutation(value))
426 }
427 Err(e) => Err(e),
428 }
429 }
430 .instrument(span)
431 .await
432 }
433
434 async fn insert_job(
435 tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
436 job: &PendingJob,
437 ) -> Result<()> {
438 let now = Utc::now();
439 sqlx::query!(
440 r#"
441 INSERT INTO forge_jobs (
442 id, job_type, input, job_context, status, priority, attempts, max_attempts,
443 worker_capability, owner_subject, scheduled_at, created_at
444 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
445 "#,
446 job.id,
447 &job.job_type,
448 job.args as _,
449 job.context as _,
450 JobStatus::Pending.as_str(),
451 job.priority,
452 0i32,
453 job.max_attempts,
454 job.worker_capability.as_deref(),
455 job.owner_subject as _,
456 now,
457 now,
458 )
459 .execute(&mut **tx)
460 .await
461 .map_err(|e| ForgeError::Database(e.to_string()))?;
462
463 Ok(())
464 }
465
466 async fn insert_workflow(
467 tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
468 workflow: &PendingWorkflow,
469 ) -> Result<()> {
470 let now = Utc::now();
471 sqlx::query!(
472 r#"
473 INSERT INTO forge_workflow_runs (
474 id, workflow_name, owner_subject, input, status, current_step,
475 step_results, started_at, trace_id
476 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
477 "#,
478 workflow.id,
479 &workflow.workflow_name,
480 workflow.owner_subject as _,
481 workflow.input as _,
482 WorkflowStatus::Created.as_str(),
483 Option::<String>::None,
484 serde_json::json!({}) as _,
485 now,
486 workflow.id.to_string(),
487 )
488 .execute(&mut **tx)
489 .await
490 .map_err(|e| ForgeError::Database(e.to_string()))?;
491
492 Ok(())
493 }
494}
495
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds an authenticated context for `user_id` with the given roles and a
    /// `sub` claim equal to the user id.
    fn authenticated_user(user_id: uuid::Uuid, roles: &[&str]) -> AuthContext {
        AuthContext::authenticated(
            user_id,
            roles.iter().map(|r| r.to_string()).collect(),
            HashMap::from([(
                "sub".to_string(),
                serde_json::Value::String(user_id.to_string()),
            )]),
        )
    }

    #[test]
    fn test_check_auth_public() {
        let info = FunctionInfo {
            name: "test",
            description: None,
            kind: FunctionKind::Query,
            required_role: None,
            is_public: true,
            cache_ttl: None,
            timeout: None,
            http_timeout: None,
            rate_limit_requests: None,
            rate_limit_per_secs: None,
            rate_limit_key: None,
            log_level: None,
            table_dependencies: &[],
            selected_columns: &[],
            transactional: false,
            consistent: false,
            has_input_args: false,
        };

        let auth = AuthContext::unauthenticated();

        // A public function must be callable anonymously.
        assert!(require_auth(info.is_public, info.required_role, &auth).is_ok());
    }

    #[test]
    fn test_require_auth_rejects_anonymous_for_private() {
        let auth = AuthContext::unauthenticated();
        let result = require_auth(false, None, &auth);
        assert!(matches!(result, Err(ForgeError::Unauthorized(_))));
    }

    #[test]
    fn test_require_auth_enforces_required_role() {
        let auth = authenticated_user(uuid::Uuid::new_v4(), &["user"]);
        assert!(require_auth(false, Some("user"), &auth).is_ok());
        assert!(matches!(
            require_auth(false, Some("billing"), &auth),
            Err(ForgeError::Forbidden(_))
        ));
    }

    #[test]
    fn test_identity_args_reject_cross_user_value() {
        let auth = authenticated_user(uuid::Uuid::new_v4(), &["user"]);
        let args = serde_json::json!({
            "user_id": uuid::Uuid::new_v4().to_string()
        });

        let result = auth.check_identity_args("list_orders", &args, true);
        assert!(matches!(result, Err(ForgeError::Forbidden(_))));
    }

    #[test]
    fn test_identity_args_allow_matching_subject() {
        let sub = "user_123";
        let auth = AuthContext::authenticated_without_uuid(
            vec!["user".to_string()],
            HashMap::from([(
                "sub".to_string(),
                serde_json::Value::String(sub.to_string()),
            )]),
        );
        let args = serde_json::json!({
            "subject": sub
        });

        let result = auth.check_identity_args("list_orders", &args, true);
        assert!(result.is_ok());
    }

    #[test]
    fn test_identity_args_require_auth_for_identity_keys() {
        let auth = AuthContext::unauthenticated();
        let args = serde_json::json!({
            "user_id": uuid::Uuid::new_v4().to_string()
        });

        let result = auth.check_identity_args("list_orders", &args, true);
        assert!(matches!(result, Err(ForgeError::Unauthorized(_))));
    }

    #[test]
    fn test_identity_args_require_scope_for_non_public_calls() {
        let auth = authenticated_user(uuid::Uuid::new_v4(), &["user"]);

        let result = auth.check_identity_args("list_orders", &serde_json::json!({}), true);
        assert!(matches!(result, Err(ForgeError::Forbidden(_))));
    }

    #[test]
    fn test_identity_args_skip_scope_for_no_input_functions() {
        let auth = authenticated_user(uuid::Uuid::new_v4(), &["user"]);

        let result = auth.check_identity_args("list_todos", &serde_json::Value::Null, false);
        assert!(result.is_ok());
    }

    #[test]
    fn test_auth_cache_scope_anonymous() {
        let scope = FunctionRouter::auth_cache_scope(&AuthContext::unauthenticated());
        assert_eq!(scope.as_deref(), Some("anon"));
    }

    #[test]
    fn test_auth_cache_scope_ignores_role_order_and_duplicates() {
        let user_id = uuid::Uuid::new_v4();
        let a = authenticated_user(user_id, &["user", "editor"]);
        let b = authenticated_user(user_id, &["editor", "user", "user"]);
        assert_eq!(
            FunctionRouter::auth_cache_scope(&a),
            FunctionRouter::auth_cache_scope(&b)
        );
    }

    #[test]
    fn test_auth_cache_scope_changes_with_claims() {
        let user_id = uuid::Uuid::new_v4();
        let auth_a = AuthContext::authenticated(
            user_id,
            vec!["user".to_string()],
            HashMap::from([
                (
                    "sub".to_string(),
                    serde_json::Value::String(user_id.to_string()),
                ),
                (
                    "tenant_id".to_string(),
                    serde_json::Value::String("tenant-a".to_string()),
                ),
            ]),
        );
        let auth_b = AuthContext::authenticated(
            user_id,
            vec!["user".to_string()],
            HashMap::from([
                (
                    "sub".to_string(),
                    serde_json::Value::String(user_id.to_string()),
                ),
                (
                    "tenant_id".to_string(),
                    serde_json::Value::String("tenant-b".to_string()),
                ),
            ]),
        );

        let scope_a = FunctionRouter::auth_cache_scope(&auth_a);
        let scope_b = FunctionRouter::auth_cache_scope(&auth_b);
        assert_ne!(scope_a, scope_b);
    }
}
647}