1use std::collections::BTreeMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use chrono::Utc;
6use forge_core::{
7 AuthContext, CircuitBreakerClient, ForgeError, FunctionInfo, FunctionKind, JobDispatch,
8 MutationContext, OutboxBuffer, PendingJob, PendingWorkflow, QueryContext, RequestMetadata,
9 Result, WorkflowDispatch,
10 job::JobStatus,
11 rate_limit::{RateLimitConfig, RateLimitKey},
12 workflow::WorkflowStatus,
13};
14use serde_json::Value;
15use tracing::Instrument;
16
17use super::cache::QueryCache;
18use super::registry::{BoxedMutationFn, FunctionEntry, FunctionRegistry};
19use crate::db::Database;
20use crate::rate_limit::HybridRateLimiter;
21
22fn require_auth(is_public: bool, required_role: Option<&str>, auth: &AuthContext) -> Result<()> {
24 if is_public {
25 return Ok(());
26 }
27 if !auth.is_authenticated() {
28 return Err(ForgeError::Unauthorized("Authentication required".into()));
29 }
30 if let Some(role) = required_role
31 && !auth.has_role(role)
32 {
33 return Err(ForgeError::Forbidden(format!("Role '{role}' required")));
34 }
35 Ok(())
36}
37
/// Successful outcome of routing a function call, tagged by the kind of
/// function that handled it. The `Job`/`Workflow` variants carry the dispatch
/// acknowledgement payload (e.g. `{"job_id": ...}`), not a final result.
pub enum RouteResult {
    /// Value returned by a query handler.
    Query(Value),
    /// Value returned by a mutation handler.
    Mutation(Value),
    /// JSON object containing the dispatched job's id.
    Job(Value),
    /// JSON object containing the started workflow's id.
    Workflow(Value),
}
49
/// Central dispatcher that routes incoming calls to registered query/mutation
/// handlers, falling back to job and workflow dispatch by name.
pub struct FunctionRouter {
    /// Registered query and mutation handlers.
    registry: Arc<FunctionRegistry>,
    /// Database handle exposing primary and read pools.
    db: Database,
    /// Outbound HTTP client handed to mutation contexts.
    http_client: CircuitBreakerClient,
    /// Optional job dispatcher; `None` disables job routing.
    job_dispatcher: Option<Arc<dyn JobDispatch>>,
    /// Optional workflow dispatcher; `None` disables workflow routing.
    workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
    /// Enforces per-function rate limits.
    rate_limiter: HybridRateLimiter,
    /// TTL-based cache for query results, partitioned by auth scope.
    query_cache: QueryCache,
    /// Optional issuer used to mint auth tokens inside mutations.
    token_issuer: Option<Arc<dyn forge_core::TokenIssuer>>,
    /// Lifetime configuration applied to issued tokens.
    token_ttl: forge_core::AuthTokenTtl,
}
62
63impl FunctionRouter {
64 pub fn new(registry: Arc<FunctionRegistry>, db: Database) -> Self {
66 let rate_limiter = HybridRateLimiter::new(db.primary().clone());
67 Self {
68 registry,
69 db,
70 http_client: CircuitBreakerClient::with_defaults(reqwest::Client::new()),
71 job_dispatcher: None,
72 workflow_dispatcher: None,
73 rate_limiter,
74 query_cache: QueryCache::new(),
75 token_issuer: None,
76 token_ttl: forge_core::AuthTokenTtl::default(),
77 }
78 }
79
80 pub fn with_http_client(
82 registry: Arc<FunctionRegistry>,
83 db: Database,
84 http_client: CircuitBreakerClient,
85 ) -> Self {
86 let rate_limiter = HybridRateLimiter::new(db.primary().clone());
87 Self {
88 registry,
89 db,
90 http_client,
91 job_dispatcher: None,
92 workflow_dispatcher: None,
93 rate_limiter,
94 query_cache: QueryCache::new(),
95 token_issuer: None,
96 token_ttl: forge_core::AuthTokenTtl::default(),
97 }
98 }
99
    /// Builder: installs the issuer used to mint auth tokens inside mutations.
    pub fn with_token_issuer(mut self, issuer: Arc<dyn forge_core::TokenIssuer>) -> Self {
        self.token_issuer = Some(issuer);
        self
    }
105
    /// Builder: overrides the TTL configuration applied to issued tokens.
    pub fn with_token_ttl(mut self, ttl: forge_core::AuthTokenTtl) -> Self {
        self.token_ttl = ttl;
        self
    }
111
    /// In-place variant of [`Self::with_token_ttl`] for an already-built router.
    pub fn set_token_ttl(&mut self, ttl: forge_core::AuthTokenTtl) {
        self.token_ttl = ttl;
    }
116
    /// Builder: enables routing of unregistered names to background jobs.
    pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
        self.job_dispatcher = Some(dispatcher);
        self
    }
122
    /// Builder: enables routing of unregistered names to workflows.
    pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
        self.workflow_dispatcher = Some(dispatcher);
        self
    }
128
    /// Routes a named function call to its handler.
    ///
    /// Resolution order: registered query/mutation functions first, then job
    /// dispatch by name, then workflow dispatch by name. Auth and rate limits
    /// are checked before any handler runs.
    ///
    /// # Errors
    /// `Unauthorized`/`Forbidden` on auth failure, limiter errors when the
    /// rate limit is exceeded, handler errors, or `NotFound` when no
    /// function, job, or workflow matches `function_name`.
    pub async fn route(
        &self,
        function_name: &str,
        args: Value,
        auth: AuthContext,
        request: RequestMetadata,
    ) -> Result<RouteResult> {
        if let Some(entry) = self.registry.get(function_name) {
            self.check_auth(entry.info(), &auth)?;
            self.check_rate_limit(entry.info(), function_name, &auth, &request)
                .await?;

            return match entry {
                FunctionEntry::Query { handler, info, .. } => {
                    // Consistent queries must observe the latest writes, so
                    // they read from the primary; others use the read pool.
                    let pool = if info.consistent {
                        self.db.primary().clone()
                    } else {
                        self.db.read_pool().clone()
                    };

                    let auth_scope = Self::auth_cache_scope(&auth);
                    if let Some(ttl) = info.cache_ttl {
                        // Cache hit: serve the stored value without invoking
                        // the handler.
                        if let Some(cached) =
                            self.query_cache
                                .get(function_name, &args, auth_scope.as_deref())
                        {
                            return Ok(RouteResult::Query(Value::clone(&cached)));
                        }

                        let ctx = QueryContext::new(pool, auth, request);
                        let result = handler(&ctx, args.clone()).await?;

                        // Populate the cache for subsequent identical calls
                        // within the same auth scope.
                        self.query_cache.set(
                            function_name,
                            &args,
                            auth_scope.as_deref(),
                            result.clone(),
                            Duration::from_secs(ttl),
                        );

                        Ok(RouteResult::Query(result))
                    } else {
                        let ctx = QueryContext::new(pool, auth, request);
                        let result = handler(&ctx, args).await?;
                        Ok(RouteResult::Query(result))
                    }
                }
                FunctionEntry::Mutation { handler, info } => {
                    if info.transactional {
                        // Transactional mutations run inside a DB transaction
                        // with outbox-buffered job/workflow dispatch.
                        self.execute_transactional(info, handler, args, auth, request)
                            .await
                    } else {
                        let mut ctx = MutationContext::with_dispatch(
                            self.db.primary().clone(),
                            auth,
                            request,
                            self.http_client.clone(),
                            self.job_dispatcher.clone(),
                            self.workflow_dispatcher.clone(),
                        );
                        if let Some(ref issuer) = self.token_issuer {
                            ctx.set_token_issuer(issuer.clone());
                        }
                        ctx.set_token_ttl(self.token_ttl.clone());
                        ctx.set_http_timeout(info.http_timeout.map(Duration::from_secs));
                        let result = handler(&ctx, args).await?;
                        Ok(RouteResult::Mutation(result))
                    }
                }
            };
        }

        // Not a registered function: try dispatching it as a background job.
        if let Some(ref job_dispatcher) = self.job_dispatcher
            && let Some(job_info) = job_dispatcher.get_info(function_name)
        {
            self.check_job_auth(&job_info, &auth)?;
            match job_dispatcher
                .dispatch_by_name(function_name, args.clone(), auth.principal_id())
                .await
            {
                Ok(job_id) => {
                    return Ok(RouteResult::Job(serde_json::json!({ "job_id": job_id })));
                }
                // NotFound from the dispatcher is treated as "not a job";
                // fall through to the workflow lookup below.
                Err(ForgeError::NotFound(_)) => {}
                Err(e) => return Err(e),
            }
        }

        // Finally, try starting a workflow with this name.
        if let Some(ref workflow_dispatcher) = self.workflow_dispatcher
            && let Some(workflow_info) = workflow_dispatcher.get_info(function_name)
        {
            self.check_workflow_auth(&workflow_info, &auth)?;
            match workflow_dispatcher
                .start_by_name(function_name, args.clone(), auth.principal_id())
                .await
            {
                Ok(workflow_id) => {
                    return Ok(RouteResult::Workflow(
                        serde_json::json!({ "workflow_id": workflow_id }),
                    ));
                }
                // Same "not actually a workflow" fall-through as for jobs.
                Err(ForgeError::NotFound(_)) => {}
                Err(e) => return Err(e),
            }
        }

        Err(ForgeError::NotFound(format!(
            "Function '{}' not found",
            function_name
        )))
    }
241
    /// Auth gate for registered query/mutation functions.
    fn check_auth(&self, info: &FunctionInfo, auth: &AuthContext) -> Result<()> {
        require_auth(info.is_public, info.required_role, auth)
    }
245
    /// Auth gate for job dispatch, driven by the job's own metadata.
    fn check_job_auth(&self, info: &forge_core::job::JobInfo, auth: &AuthContext) -> Result<()> {
        require_auth(info.is_public, info.required_role, auth)
    }
249
    /// Auth gate for workflow dispatch, driven by the workflow's own metadata.
    fn check_workflow_auth(
        &self,
        info: &forge_core::workflow::WorkflowInfo,
        auth: &AuthContext,
    ) -> Result<()> {
        require_auth(info.is_public, info.required_role, auth)
    }
257
258 async fn check_rate_limit(
260 &self,
261 info: &FunctionInfo,
262 function_name: &str,
263 auth: &AuthContext,
264 request: &RequestMetadata,
265 ) -> Result<()> {
266 let (requests, per_secs) = match (info.rate_limit_requests, info.rate_limit_per_secs) {
268 (Some(r), Some(p)) => (r, p),
269 _ => return Ok(()),
270 };
271
272 let key_str = info.rate_limit_key.unwrap_or("user");
274 let key_type: RateLimitKey = match key_str.parse() {
275 Ok(k) => k,
276 Err(_) => {
277 tracing::error!(
278 function = %function_name,
279 key = %key_str,
280 "Invalid rate limit key, falling back to 'user'"
281 );
282 RateLimitKey::default()
283 }
284 };
285
286 let config =
287 RateLimitConfig::new(requests, Duration::from_secs(per_secs)).with_key(key_type);
288
289 let bucket_key = self
291 .rate_limiter
292 .build_key(key_type, function_name, auth, request);
293
294 self.rate_limiter.enforce(&bucket_key, &config).await?;
296
297 Ok(())
298 }
299
300 fn auth_cache_scope(auth: &AuthContext) -> Option<String> {
301 if !auth.is_authenticated() {
302 return Some("anon".to_string());
303 }
304
305 let mut roles = auth.roles().to_vec();
307 roles.sort();
308 roles.dedup();
309
310 let mut claims = BTreeMap::new();
311 for (k, v) in auth.claims() {
312 claims.insert(k.clone(), v.clone());
313 }
314
315 use std::collections::hash_map::DefaultHasher;
316 use std::hash::{Hash, Hasher};
317
318 let mut hasher = DefaultHasher::new();
319 roles.hash(&mut hasher);
320 serde_json::to_string(&claims)
321 .unwrap_or_default()
322 .hash(&mut hasher);
323
324 let principal = auth
325 .principal_id()
326 .unwrap_or_else(|| "authenticated".to_string());
327
328 Some(format!(
329 "subject:{principal}:scope:{:016x}",
330 hasher.finish()
331 ))
332 }
333
    /// Returns the kind of a registered function, or `None` if the name is
    /// not in the registry (job/workflow names are not covered).
    pub fn get_function_kind(&self, function_name: &str) -> Option<FunctionKind> {
        self.registry.get(function_name).map(|e| e.kind())
    }
338
    /// Whether `function_name` is a registered query/mutation function.
    /// Does not consult the job or workflow dispatchers.
    pub fn has_function(&self, function_name: &str) -> bool {
        self.registry.get(function_name).is_some()
    }
343
    /// Runs a transactional mutation: the handler plus any jobs/workflows it
    /// enqueues commit atomically in a single database transaction.
    ///
    /// Enqueued jobs and workflows are captured in an outbox buffer and
    /// written as rows inside the same transaction, so they only become
    /// visible if the mutation commits.
    async fn execute_transactional(
        &self,
        info: &FunctionInfo,
        handler: &BoxedMutationFn,
        args: Value,
        auth: AuthContext,
        request: RequestMetadata,
    ) -> Result<RouteResult> {
        let span = tracing::info_span!("db.transaction", db.system = "postgresql",);

        async {
            let primary = self.db.primary();
            let tx = primary
                .begin()
                .await
                .map_err(|e| ForgeError::Database(e.to_string()))?;

            // Lookup closure the mutation context uses to validate job names
            // at enqueue time.
            let job_dispatcher = self.job_dispatcher.clone();
            let job_lookup: forge_core::JobInfoLookup =
                Arc::new(move |name: &str| job_dispatcher.as_ref().and_then(|d| d.get_info(name)));

            let (mut ctx, tx_handle, outbox) = MutationContext::with_transaction(
                primary.clone(),
                tx,
                auth,
                request,
                self.http_client.clone(),
                job_lookup,
            );
            if let Some(ref issuer) = self.token_issuer {
                ctx.set_token_issuer(issuer.clone());
            }
            ctx.set_token_ttl(self.token_ttl.clone());
            ctx.set_http_timeout(info.http_timeout.map(Duration::from_secs));

            match handler(&ctx, args).await {
                Ok(value) => {
                    // Snapshot the outbox while holding the lock briefly; a
                    // poisoned mutex is recovered because the buffered data
                    // is still usable even if a panic occurred mid-update.
                    let buffer = {
                        let guard = outbox.lock().unwrap_or_else(|poisoned| {
                            tracing::error!("Outbox mutex was poisoned, recovering");
                            poisoned.into_inner()
                        });
                        OutboxBuffer {
                            jobs: guard.jobs.clone(),
                            workflows: guard.workflows.clone(),
                        }
                    };

                    // Reclaim sole ownership of the transaction handle; this
                    // fails only if the handler leaked a clone of it.
                    let mut tx = Arc::try_unwrap(tx_handle)
                        .map_err(|_| ForgeError::Internal("Transaction still in use".into()))?
                        .into_inner();

                    for job in &buffer.jobs {
                        Self::insert_job(&mut tx, job).await?;
                    }

                    for workflow in &buffer.workflows {
                        // An unknown workflow name aborts (rolls back) the
                        // whole transaction before anything is committed.
                        if self
                            .workflow_dispatcher
                            .as_ref()
                            .and_then(|d| d.get_info(&workflow.workflow_name))
                            .is_none()
                        {
                            return Err(ForgeError::NotFound(format!(
                                "Workflow '{}' not found",
                                workflow.workflow_name
                            )));
                        }
                        Self::insert_workflow(&mut tx, workflow).await?;
                    }

                    tx.commit()
                        .await
                        .map_err(|e| ForgeError::Database(e.to_string()))?;

                    Ok(RouteResult::Mutation(value))
                }
                Err(e) => Err(e),
            }
        }
        .instrument(span)
        .await
    }
427
    /// Inserts a buffered job row into `forge_jobs` inside the caller's
    /// transaction; the job becomes visible to workers only on commit.
    async fn insert_job(
        tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
        job: &PendingJob,
    ) -> Result<()> {
        let now = Utc::now();
        sqlx::query!(
            r#"
            INSERT INTO forge_jobs (
                id, job_type, input, job_context, status, priority, attempts, max_attempts,
                worker_capability, owner_subject, scheduled_at, created_at
            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
            "#,
            job.id,
            &job.job_type,
            job.args as _,
            job.context as _,
            JobStatus::Pending.as_str(),
            job.priority,
            // New jobs start with zero attempts.
            0i32,
            job.max_attempts,
            job.worker_capability.as_deref(),
            job.owner_subject as _,
            now,
            now,
        )
        .execute(&mut **tx)
        .await
        .map_err(|e| ForgeError::Database(e.to_string()))?;

        Ok(())
    }
459
    /// Inserts a buffered workflow row into `forge_workflow_runs` inside the
    /// caller's transaction, in `Created` status with an empty step state.
    async fn insert_workflow(
        tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
        workflow: &PendingWorkflow,
    ) -> Result<()> {
        let now = Utc::now();
        sqlx::query!(
            r#"
            INSERT INTO forge_workflow_runs (
                id, workflow_name, owner_subject, input, status, current_step,
                step_results, started_at, trace_id
            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
            "#,
            workflow.id,
            &workflow.workflow_name,
            workflow.owner_subject as _,
            workflow.input as _,
            WorkflowStatus::Created.as_str(),
            // No step has run yet.
            Option::<String>::None,
            serde_json::json!({}) as _,
            now,
            // The workflow id doubles as the trace id for correlation.
            workflow.id.to_string(),
        )
        .execute(&mut **tx)
        .await
        .map_err(|e| ForgeError::Database(e.to_string()))?;

        Ok(())
    }
488}
489
490#[cfg(test)]
491mod tests {
492 use super::*;
493 use std::collections::HashMap;
494
495 #[test]
496 fn test_check_auth_public() {
497 let info = FunctionInfo {
498 name: "test",
499 description: None,
500 kind: FunctionKind::Query,
501 required_role: None,
502 is_public: true,
503 cache_ttl: None,
504 timeout: None,
505 http_timeout: None,
506 rate_limit_requests: None,
507 rate_limit_per_secs: None,
508 rate_limit_key: None,
509 log_level: None,
510 table_dependencies: &[],
511 selected_columns: &[],
512 transactional: false,
513 consistent: false,
514 max_upload_size_bytes: None,
515 };
516
517 let _auth = AuthContext::unauthenticated();
518
519 assert!(info.is_public);
522 }
523
524 #[test]
525 fn test_auth_cache_scope_changes_with_claims() {
526 let user_id = uuid::Uuid::new_v4();
527 let auth_a = AuthContext::authenticated(
528 user_id,
529 vec!["user".to_string()],
530 HashMap::from([
531 (
532 "sub".to_string(),
533 serde_json::Value::String(user_id.to_string()),
534 ),
535 (
536 "tenant_id".to_string(),
537 serde_json::Value::String("tenant-a".to_string()),
538 ),
539 ]),
540 );
541 let auth_b = AuthContext::authenticated(
542 user_id,
543 vec!["user".to_string()],
544 HashMap::from([
545 (
546 "sub".to_string(),
547 serde_json::Value::String(user_id.to_string()),
548 ),
549 (
550 "tenant_id".to_string(),
551 serde_json::Value::String("tenant-b".to_string()),
552 ),
553 ]),
554 );
555
556 let scope_a = FunctionRouter::auth_cache_scope(&auth_a);
557 let scope_b = FunctionRouter::auth_cache_scope(&auth_b);
558 assert_ne!(scope_a, scope_b);
559 }
560}