1use std::collections::BTreeMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use chrono::Utc;
6use forge_core::{
7 AuthContext, CircuitBreakerClient, ForgeError, FunctionInfo, FunctionKind, JobDispatch,
8 MutationContext, OutboxBuffer, PendingJob, PendingWorkflow, QueryContext, RequestMetadata,
9 Result, WorkflowDispatch,
10 job::JobStatus,
11 rate_limit::{RateLimitConfig, RateLimitKey},
12 workflow::WorkflowStatus,
13};
14use serde_json::Value;
15use tracing::Instrument;
16
17use super::cache::QueryCache;
18use super::registry::{BoxedMutationFn, FunctionEntry, FunctionRegistry};
19use crate::db::Database;
20use crate::rate_limit::HybridRateLimiter;
21
22fn require_auth(is_public: bool, required_role: Option<&str>, auth: &AuthContext) -> Result<()> {
24 if is_public {
25 return Ok(());
26 }
27 if !auth.is_authenticated() {
28 return Err(ForgeError::Unauthorized("Authentication required".into()));
29 }
30 if let Some(role) = required_role
31 && !auth.has_role(role)
32 {
33 return Err(ForgeError::Forbidden(format!("Role '{role}' required")));
34 }
35 Ok(())
36}
37
38pub enum RouteResult {
40 Query(Value),
42 Mutation(Value),
44 Job(Value),
46 Workflow(Value),
48}
49
/// Routes named calls to registered queries and mutations, falling back to the optional
/// job dispatcher and then the optional workflow dispatcher.
pub struct FunctionRouter {
    /// Registered query and mutation handlers, looked up by function name.
    registry: Arc<FunctionRegistry>,
    /// Database handle: mutations use the primary pool; queries use the primary pool when
    /// marked `consistent`, otherwise `read_pool()`.
    db: Database,
    /// HTTP client handed to every mutation context.
    http_client: CircuitBreakerClient,
    /// Consulted for names missing from `registry`; also supplies job metadata to
    /// transactional mutations.
    job_dispatcher: Option<Arc<dyn JobDispatch>>,
    /// Consulted for names handled neither by `registry` nor by the job dispatcher; also
    /// supplies workflow versions when a transactional mutation's outbox is flushed.
    workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
    /// Rate limiter for registered functions (constructed from the primary pool).
    rate_limiter: HybridRateLimiter,
    /// Result cache for queries that declare a `cache_ttl`.
    query_cache: QueryCache,
    /// Optional token issuer installed on each mutation context.
    token_issuer: Option<Arc<dyn forge_core::TokenIssuer>>,
    /// Token TTL settings copied into each mutation context.
    token_ttl: forge_core::AuthTokenTtl,
}
62
63impl FunctionRouter {
64 pub fn new(registry: Arc<FunctionRegistry>, db: Database) -> Self {
66 let rate_limiter = HybridRateLimiter::new(db.primary().clone());
67 Self {
68 registry,
69 db,
70 http_client: CircuitBreakerClient::with_defaults(reqwest::Client::new()),
71 job_dispatcher: None,
72 workflow_dispatcher: None,
73 rate_limiter,
74 query_cache: QueryCache::new(),
75 token_issuer: None,
76 token_ttl: forge_core::AuthTokenTtl::default(),
77 }
78 }
79
80 pub fn with_http_client(
82 registry: Arc<FunctionRegistry>,
83 db: Database,
84 http_client: CircuitBreakerClient,
85 ) -> Self {
86 let rate_limiter = HybridRateLimiter::new(db.primary().clone());
87 Self {
88 registry,
89 db,
90 http_client,
91 job_dispatcher: None,
92 workflow_dispatcher: None,
93 rate_limiter,
94 query_cache: QueryCache::new(),
95 token_issuer: None,
96 token_ttl: forge_core::AuthTokenTtl::default(),
97 }
98 }
99
100 pub fn with_token_issuer(mut self, issuer: Arc<dyn forge_core::TokenIssuer>) -> Self {
102 self.token_issuer = Some(issuer);
103 self
104 }
105
106 pub fn with_token_ttl(mut self, ttl: forge_core::AuthTokenTtl) -> Self {
108 self.token_ttl = ttl;
109 self
110 }
111
112 pub fn set_token_ttl(&mut self, ttl: forge_core::AuthTokenTtl) {
114 self.token_ttl = ttl;
115 }
116
117 pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
119 self.job_dispatcher = Some(dispatcher);
120 self
121 }
122
123 pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
125 self.workflow_dispatcher = Some(dispatcher);
126 self
127 }
128
129 pub async fn route(
130 &self,
131 function_name: &str,
132 args: Value,
133 auth: AuthContext,
134 request: RequestMetadata,
135 ) -> Result<RouteResult> {
136 if let Some(entry) = self.registry.get(function_name) {
137 self.check_auth(entry.info(), &auth)?;
138 self.check_rate_limit(entry.info(), function_name, &auth, &request)
139 .await?;
140 let enforce = !entry.info().is_public && entry.info().has_input_args;
143 auth.check_identity_args(function_name, &args, enforce)?;
144
145 return match entry {
146 FunctionEntry::Query { handler, info, .. } => {
147 let pool = if info.consistent {
148 self.db.primary().clone()
149 } else {
150 self.db.read_pool().clone()
151 };
152
153 let auth_scope = Self::auth_cache_scope(&auth);
154 if let Some(ttl) = info.cache_ttl {
155 if let Some(cached) =
156 self.query_cache
157 .get(function_name, &args, auth_scope.as_deref())
158 {
159 return Ok(RouteResult::Query(Value::clone(&cached)));
160 }
161
162 let ctx = QueryContext::new(pool, auth, request);
163 let result = handler(&ctx, args.clone()).await?;
164
165 self.query_cache.set(
166 function_name,
167 &args,
168 auth_scope.as_deref(),
169 result.clone(),
170 Duration::from_secs(ttl),
171 );
172
173 Ok(RouteResult::Query(result))
174 } else {
175 let ctx = QueryContext::new(pool, auth, request);
176 let result = handler(&ctx, args).await?;
177 Ok(RouteResult::Query(result))
178 }
179 }
180 FunctionEntry::Mutation { handler, info } => {
181 if info.transactional {
182 self.execute_transactional(info, handler, args, auth, request)
183 .await
184 } else {
185 let mut ctx = MutationContext::with_dispatch(
187 self.db.primary().clone(),
188 auth,
189 request,
190 self.http_client.clone(),
191 self.job_dispatcher.clone(),
192 self.workflow_dispatcher.clone(),
193 );
194 if let Some(ref issuer) = self.token_issuer {
195 ctx.set_token_issuer(issuer.clone());
196 }
197 ctx.set_token_ttl(self.token_ttl.clone());
198 ctx.set_http_timeout(info.http_timeout.map(Duration::from_secs));
199 let result = handler(&ctx, args).await?;
200 Ok(RouteResult::Mutation(result))
201 }
202 }
203 };
204 }
205
206 if let Some(ref job_dispatcher) = self.job_dispatcher
207 && let Some(job_info) = job_dispatcher.get_info(function_name)
208 {
209 self.check_job_auth(&job_info, &auth)?;
210 auth.check_identity_args(function_name, &args, !job_info.is_public)?;
211 match job_dispatcher
212 .dispatch_by_name(function_name, args.clone(), auth.principal_id())
213 .await
214 {
215 Ok(job_id) => {
216 return Ok(RouteResult::Job(serde_json::json!({ "job_id": job_id })));
217 }
218 Err(ForgeError::NotFound(_)) => {}
219 Err(e) => return Err(e),
220 }
221 }
222
223 if let Some(ref workflow_dispatcher) = self.workflow_dispatcher
224 && let Some(workflow_info) = workflow_dispatcher.get_info(function_name)
225 {
226 self.check_workflow_auth(&workflow_info, &auth)?;
227 auth.check_identity_args(function_name, &args, !workflow_info.is_public)?;
228 match workflow_dispatcher
229 .start_by_name(function_name, args.clone(), auth.principal_id())
230 .await
231 {
232 Ok(workflow_id) => {
233 return Ok(RouteResult::Workflow(
234 serde_json::json!({ "workflow_id": workflow_id }),
235 ));
236 }
237 Err(ForgeError::NotFound(_)) => {}
238 Err(e) => return Err(e),
239 }
240 }
241
242 Err(ForgeError::NotFound(format!(
243 "Function '{}' not found",
244 function_name
245 )))
246 }
247
248 fn check_auth(&self, info: &FunctionInfo, auth: &AuthContext) -> Result<()> {
249 require_auth(info.is_public, info.required_role, auth)
250 }
251
252 fn check_job_auth(&self, info: &forge_core::job::JobInfo, auth: &AuthContext) -> Result<()> {
253 require_auth(info.is_public, info.required_role, auth)
254 }
255
256 fn check_workflow_auth(
257 &self,
258 info: &forge_core::workflow::WorkflowInfo,
259 auth: &AuthContext,
260 ) -> Result<()> {
261 require_auth(info.is_public, info.required_role, auth)
262 }
263
264 async fn check_rate_limit(
266 &self,
267 info: &FunctionInfo,
268 function_name: &str,
269 auth: &AuthContext,
270 request: &RequestMetadata,
271 ) -> Result<()> {
272 let (requests, per_secs) = match (info.rate_limit_requests, info.rate_limit_per_secs) {
274 (Some(r), Some(p)) => (r, p),
275 _ => return Ok(()),
276 };
277
278 let key_str = info.rate_limit_key.unwrap_or("user");
280 let key_type: RateLimitKey = match key_str.parse() {
281 Ok(k) => k,
282 Err(_) => {
283 tracing::error!(
284 function = %function_name,
285 key = %key_str,
286 "Invalid rate limit key, falling back to 'user'"
287 );
288 RateLimitKey::default()
289 }
290 };
291
292 let config =
293 RateLimitConfig::new(requests, Duration::from_secs(per_secs)).with_key(key_type);
294
295 let bucket_key = self
297 .rate_limiter
298 .build_key(key_type, function_name, auth, request);
299
300 self.rate_limiter.enforce(&bucket_key, &config).await?;
302
303 Ok(())
304 }
305
306 fn auth_cache_scope(auth: &AuthContext) -> Option<String> {
307 if !auth.is_authenticated() {
308 return Some("anon".to_string());
309 }
310
311 let mut roles = auth.roles().to_vec();
313 roles.sort();
314 roles.dedup();
315
316 let mut claims = BTreeMap::new();
317 for (k, v) in auth.claims() {
318 claims.insert(k.clone(), v.clone());
319 }
320
321 use std::collections::hash_map::DefaultHasher;
322 use std::hash::{Hash, Hasher};
323
324 let mut hasher = DefaultHasher::new();
325 roles.hash(&mut hasher);
326 serde_json::to_string(&claims)
327 .unwrap_or_default()
328 .hash(&mut hasher);
329
330 let principal = auth
331 .principal_id()
332 .unwrap_or_else(|| "authenticated".to_string());
333
334 Some(format!(
335 "subject:{principal}:scope:{:016x}",
336 hasher.finish()
337 ))
338 }
339
340 pub fn get_function_kind(&self, function_name: &str) -> Option<FunctionKind> {
342 self.registry.get(function_name).map(|e| e.kind())
343 }
344
345 pub fn has_function(&self, function_name: &str) -> bool {
347 self.registry.get(function_name).is_some()
348 }
349
350 async fn execute_transactional(
351 &self,
352 info: &FunctionInfo,
353 handler: &BoxedMutationFn,
354 args: Value,
355 auth: AuthContext,
356 request: RequestMetadata,
357 ) -> Result<RouteResult> {
358 let span = tracing::info_span!("db.transaction", db.system = "postgresql",);
359
360 async {
361 let primary = self.db.primary();
362 let tx = primary
363 .begin()
364 .await
365 .map_err(|e| ForgeError::Database(e.to_string()))?;
366
367 let job_dispatcher = self.job_dispatcher.clone();
368 let job_lookup: forge_core::JobInfoLookup =
369 Arc::new(move |name: &str| job_dispatcher.as_ref().and_then(|d| d.get_info(name)));
370
371 let (mut ctx, tx_handle, outbox) = MutationContext::with_transaction(
372 primary.clone(),
373 tx,
374 auth,
375 request,
376 self.http_client.clone(),
377 job_lookup,
378 );
379 if let Some(ref issuer) = self.token_issuer {
380 ctx.set_token_issuer(issuer.clone());
381 }
382 ctx.set_token_ttl(self.token_ttl.clone());
383 ctx.set_http_timeout(info.http_timeout.map(Duration::from_secs));
384
385 match handler(&ctx, args).await {
386 Ok(value) => {
387 let buffer = {
388 let guard = outbox.lock().unwrap_or_else(|poisoned| {
389 tracing::error!("Outbox mutex was poisoned, recovering");
390 poisoned.into_inner()
391 });
392 OutboxBuffer {
393 jobs: guard.jobs.clone(),
394 workflows: guard.workflows.clone(),
395 }
396 };
397
398 let mut tx = Arc::try_unwrap(tx_handle)
399 .map_err(|_| ForgeError::Internal("Transaction still in use".into()))?
400 .into_inner();
401
402 for job in &buffer.jobs {
403 Self::insert_job(&mut tx, job).await?;
404 }
405
406 for workflow in &buffer.workflows {
407 let version = self
408 .workflow_dispatcher
409 .as_ref()
410 .and_then(|d| d.get_info(&workflow.workflow_name))
411 .map(|info| info.version)
412 .ok_or_else(|| {
413 ForgeError::NotFound(format!(
414 "Workflow '{}' not found",
415 workflow.workflow_name
416 ))
417 })?;
418 Self::insert_workflow(&mut tx, workflow, version).await?;
419 }
420
421 tx.commit()
422 .await
423 .map_err(|e| ForgeError::Database(e.to_string()))?;
424
425 Ok(RouteResult::Mutation(value))
426 }
427 Err(e) => Err(e),
428 }
429 }
430 .instrument(span)
431 .await
432 }
433
434 async fn insert_job(
435 tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
436 job: &PendingJob,
437 ) -> Result<()> {
438 let now = Utc::now();
439 sqlx::query(
440 r#"
441 INSERT INTO forge_jobs (
442 id, job_type, input, job_context, status, priority, attempts, max_attempts,
443 worker_capability, owner_subject, scheduled_at, created_at
444 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
445 "#,
446 )
447 .bind(job.id)
448 .bind(&job.job_type)
449 .bind(&job.args)
450 .bind(&job.context)
451 .bind(JobStatus::Pending.as_str())
452 .bind(job.priority)
453 .bind(0i32)
454 .bind(job.max_attempts)
455 .bind(&job.worker_capability)
456 .bind(&job.owner_subject)
457 .bind(now)
458 .bind(now)
459 .execute(&mut **tx)
460 .await
461 .map_err(|e| ForgeError::Database(e.to_string()))?;
462
463 Ok(())
464 }
465
466 async fn insert_workflow(
467 tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
468 workflow: &PendingWorkflow,
469 version: u32,
470 ) -> Result<()> {
471 let now = Utc::now();
472 sqlx::query(
473 r#"
474 INSERT INTO forge_workflow_runs (
475 id, workflow_name, version, owner_subject, input, status, current_step,
476 step_results, started_at, trace_id
477 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
478 "#,
479 )
480 .bind(workflow.id)
481 .bind(&workflow.workflow_name)
482 .bind(version as i32)
483 .bind(&workflow.owner_subject)
484 .bind(&workflow.input)
485 .bind(WorkflowStatus::Created.as_str())
486 .bind(Option::<String>::None)
487 .bind(serde_json::json!({}))
488 .bind(now)
489 .bind(workflow.id.to_string())
490 .execute(&mut **tx)
491 .await
492 .map_err(|e| ForgeError::Database(e.to_string()))?;
493
494 Ok(())
495 }
496}
497
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Authenticated context for `user_id` with the `user` role and a matching `sub` claim.
    fn user_auth(user_id: uuid::Uuid) -> AuthContext {
        AuthContext::authenticated(
            user_id,
            vec!["user".to_string()],
            HashMap::from([(
                "sub".to_string(),
                serde_json::Value::String(user_id.to_string()),
            )]),
        )
    }

    /// Authenticated context with the given roles plus `sub` and `tenant_id` claims.
    fn auth_with(user_id: uuid::Uuid, roles: &[&str], tenant: &str) -> AuthContext {
        AuthContext::authenticated(
            user_id,
            roles.iter().map(|&r| r.to_owned()).collect(),
            HashMap::from([
                (
                    "sub".to_string(),
                    serde_json::Value::String(user_id.to_string()),
                ),
                (
                    "tenant_id".to_string(),
                    serde_json::Value::String(tenant.to_string()),
                ),
            ]),
        )
    }

    #[test]
    fn test_check_auth_public() {
        let info = FunctionInfo {
            name: "test",
            description: None,
            kind: FunctionKind::Query,
            required_role: None,
            is_public: true,
            cache_ttl: None,
            timeout: None,
            http_timeout: None,
            rate_limit_requests: None,
            rate_limit_per_secs: None,
            rate_limit_key: None,
            log_level: None,
            table_dependencies: &[],
            selected_columns: &[],
            transactional: false,
            consistent: false,
            has_input_args: false,
        };

        let auth = AuthContext::unauthenticated();

        assert!(info.is_public);
        // Actually exercise the gate: public functions admit anonymous callers.
        assert!(require_auth(info.is_public, info.required_role, &auth).is_ok());
    }

    #[test]
    fn test_require_auth_rejects_anonymous_for_private() {
        let auth = AuthContext::unauthenticated();
        let result = require_auth(false, None, &auth);
        assert!(matches!(result, Err(ForgeError::Unauthorized(_))));
    }

    #[test]
    fn test_require_auth_enforces_required_role() {
        let auth = user_auth(uuid::Uuid::new_v4());
        assert!(require_auth(false, Some("user"), &auth).is_ok());
        assert!(matches!(
            require_auth(false, Some("admin"), &auth),
            Err(ForgeError::Forbidden(_))
        ));
    }

    #[test]
    fn test_identity_args_reject_cross_user_value() {
        let auth = user_auth(uuid::Uuid::new_v4());
        let args = serde_json::json!({
            "user_id": uuid::Uuid::new_v4().to_string()
        });

        let result = auth.check_identity_args("list_orders", &args, true);
        assert!(matches!(result, Err(ForgeError::Forbidden(_))));
    }

    #[test]
    fn test_identity_args_allow_matching_subject() {
        let sub = "user_123";
        let auth = AuthContext::authenticated_without_uuid(
            vec!["user".to_string()],
            HashMap::from([(
                "sub".to_string(),
                serde_json::Value::String(sub.to_string()),
            )]),
        );
        let args = serde_json::json!({
            "subject": sub
        });

        let result = auth.check_identity_args("list_orders", &args, true);
        assert!(result.is_ok());
    }

    #[test]
    fn test_identity_args_require_auth_for_identity_keys() {
        let auth = AuthContext::unauthenticated();
        let args = serde_json::json!({
            "user_id": uuid::Uuid::new_v4().to_string()
        });

        let result = auth.check_identity_args("list_orders", &args, true);
        assert!(matches!(result, Err(ForgeError::Unauthorized(_))));
    }

    #[test]
    fn test_identity_args_require_scope_for_non_public_calls() {
        let auth = user_auth(uuid::Uuid::new_v4());

        let result = auth.check_identity_args("list_orders", &serde_json::json!({}), true);
        assert!(matches!(result, Err(ForgeError::Forbidden(_))));
    }

    #[test]
    fn test_identity_args_skip_scope_for_no_input_functions() {
        let auth = user_auth(uuid::Uuid::new_v4());

        let result = auth.check_identity_args("list_todos", &serde_json::Value::Null, false);
        assert!(result.is_ok());
    }

    #[test]
    fn test_auth_cache_scope_changes_with_claims() {
        let user_id = uuid::Uuid::new_v4();
        let scope_a = FunctionRouter::auth_cache_scope(&auth_with(user_id, &["user"], "tenant-a"));
        let scope_b = FunctionRouter::auth_cache_scope(&auth_with(user_id, &["user"], "tenant-b"));
        assert_ne!(scope_a, scope_b);
    }

    #[test]
    fn test_auth_cache_scope_anonymous() {
        let scope = FunctionRouter::auth_cache_scope(&AuthContext::unauthenticated());
        assert_eq!(scope, Some("anon".to_string()));
    }

    #[test]
    fn test_auth_cache_scope_ignores_role_order_and_duplicates() {
        let user_id = uuid::Uuid::new_v4();
        let ordered = FunctionRouter::auth_cache_scope(&auth_with(user_id, &["admin", "user"], "t"));
        let shuffled =
            FunctionRouter::auth_cache_scope(&auth_with(user_id, &["user", "admin", "user"], "t"));
        assert_eq!(ordered, shuffled);
    }
}