1use std::sync::Arc;
2use std::time::Duration;
3
4use chrono::Utc;
5use forge_core::{
6 AuthContext, CircuitBreakerClient, ForgeError, FunctionInfo, FunctionKind, JobDispatch,
7 MutationContext, OutboxBuffer, PendingJob, PendingWorkflow, QueryContext, RequestMetadata,
8 Result, WorkflowDispatch,
9 job::JobStatus,
10 rate_limit::{RateLimitConfig, RateLimitKey},
11 workflow::WorkflowStatus,
12};
13use serde_json::Value;
14
15use super::cache::QueryCache;
16use super::registry::{BoxedMutationFn, FunctionEntry, FunctionRegistry};
17use crate::db::Database;
18use crate::rate_limit::RateLimiter;
19
20pub enum RouteResult {
22 Query(Value),
24 Mutation(Value),
26 Job(Value),
28 Workflow(Value),
30}
31
32pub struct FunctionRouter {
34 registry: Arc<FunctionRegistry>,
35 db: Database,
36 http_client: CircuitBreakerClient,
37 job_dispatcher: Option<Arc<dyn JobDispatch>>,
38 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
39 rate_limiter: RateLimiter,
40 query_cache: QueryCache,
41}
42
43impl FunctionRouter {
44 pub fn new(registry: Arc<FunctionRegistry>, db: Database) -> Self {
46 let rate_limiter = RateLimiter::new(db.primary().clone());
47 Self {
48 registry,
49 db,
50 http_client: CircuitBreakerClient::with_defaults(reqwest::Client::new()),
51 job_dispatcher: None,
52 workflow_dispatcher: None,
53 rate_limiter,
54 query_cache: QueryCache::new(),
55 }
56 }
57
58 pub fn with_http_client(
60 registry: Arc<FunctionRegistry>,
61 db: Database,
62 http_client: CircuitBreakerClient,
63 ) -> Self {
64 let rate_limiter = RateLimiter::new(db.primary().clone());
65 Self {
66 registry,
67 db,
68 http_client,
69 job_dispatcher: None,
70 workflow_dispatcher: None,
71 rate_limiter,
72 query_cache: QueryCache::new(),
73 }
74 }
75
76 pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
78 self.job_dispatcher = Some(dispatcher);
79 self
80 }
81
82 pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
84 self.workflow_dispatcher = Some(dispatcher);
85 self
86 }
87
88 pub async fn route(
89 &self,
90 function_name: &str,
91 args: Value,
92 auth: AuthContext,
93 request: RequestMetadata,
94 ) -> Result<RouteResult> {
95 if let Some(entry) = self.registry.get(function_name) {
96 self.check_auth(entry.info(), &auth)?;
97 self.check_rate_limit(entry.info(), function_name, &auth, &request)
98 .await?;
99
100 return match entry {
101 FunctionEntry::Query { handler, info, .. } => {
102 if let Some(ttl) = info.cache_ttl {
103 if let Some(cached) = self.query_cache.get(function_name, &args) {
104 return Ok(RouteResult::Query(cached));
105 }
106
107 let ctx = QueryContext::new(self.db.read_pool().clone(), auth, request);
109 let result = handler(&ctx, args.clone()).await?;
110
111 self.query_cache.set(
112 function_name,
113 &args,
114 result.clone(),
115 Duration::from_secs(ttl),
116 );
117
118 Ok(RouteResult::Query(result))
119 } else {
120 let ctx = QueryContext::new(self.db.read_pool().clone(), auth, request);
122 let result = handler(&ctx, args).await?;
123 Ok(RouteResult::Query(result))
124 }
125 }
126 FunctionEntry::Mutation { handler, info } => {
127 if info.transactional {
128 self.execute_transactional(handler, args, auth, request)
129 .await
130 } else {
131 let ctx = MutationContext::with_dispatch(
133 self.db.primary().clone(),
134 auth,
135 request,
136 self.http_client.clone(),
137 self.job_dispatcher.clone(),
138 self.workflow_dispatcher.clone(),
139 );
140 let result = handler(&ctx, args).await?;
141 Ok(RouteResult::Mutation(result))
142 }
143 }
144 };
145 }
146
147 if let Some(ref job_dispatcher) = self.job_dispatcher {
148 if let Some(job_info) = job_dispatcher.get_info(function_name) {
149 self.check_job_auth(&job_info, &auth)?;
150 match job_dispatcher
151 .dispatch_by_name(function_name, args.clone())
152 .await
153 {
154 Ok(job_id) => {
155 return Ok(RouteResult::Job(serde_json::json!({ "job_id": job_id })));
156 }
157 Err(ForgeError::NotFound(_)) => {}
158 Err(e) => return Err(e),
159 }
160 }
161 }
162
163 if let Some(ref workflow_dispatcher) = self.workflow_dispatcher {
164 if let Some(workflow_info) = workflow_dispatcher.get_info(function_name) {
165 self.check_workflow_auth(&workflow_info, &auth)?;
166 match workflow_dispatcher
167 .start_by_name(function_name, args.clone())
168 .await
169 {
170 Ok(workflow_id) => {
171 return Ok(RouteResult::Workflow(
172 serde_json::json!({ "workflow_id": workflow_id }),
173 ));
174 }
175 Err(ForgeError::NotFound(_)) => {}
176 Err(e) => return Err(e),
177 }
178 }
179 }
180
181 Err(ForgeError::NotFound(format!(
182 "Function '{}' not found",
183 function_name
184 )))
185 }
186
187 fn check_auth(&self, info: &FunctionInfo, auth: &AuthContext) -> Result<()> {
188 if info.is_public {
189 return Ok(());
190 }
191
192 if !auth.is_authenticated() {
193 return Err(ForgeError::Unauthorized("Authentication required".into()));
194 }
195
196 if let Some(role) = info.required_role {
197 if !auth.has_role(role) {
198 return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
199 }
200 }
201
202 Ok(())
203 }
204
205 fn check_job_auth(&self, info: &forge_core::job::JobInfo, auth: &AuthContext) -> Result<()> {
206 if info.is_public {
207 return Ok(());
208 }
209
210 if !auth.is_authenticated() {
211 return Err(ForgeError::Unauthorized("Authentication required".into()));
212 }
213
214 if let Some(role) = info.required_role {
215 if !auth.has_role(role) {
216 return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
217 }
218 }
219
220 Ok(())
221 }
222
223 fn check_workflow_auth(
224 &self,
225 info: &forge_core::workflow::WorkflowInfo,
226 auth: &AuthContext,
227 ) -> Result<()> {
228 if info.is_public {
229 return Ok(());
230 }
231
232 if !auth.is_authenticated() {
233 return Err(ForgeError::Unauthorized("Authentication required".into()));
234 }
235
236 if let Some(role) = info.required_role {
237 if !auth.has_role(role) {
238 return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
239 }
240 }
241
242 Ok(())
243 }
244
245 async fn check_rate_limit(
247 &self,
248 info: &FunctionInfo,
249 function_name: &str,
250 auth: &AuthContext,
251 request: &RequestMetadata,
252 ) -> Result<()> {
253 let (requests, per_secs) = match (info.rate_limit_requests, info.rate_limit_per_secs) {
255 (Some(r), Some(p)) => (r, p),
256 _ => return Ok(()),
257 };
258
259 let key_str = info.rate_limit_key.unwrap_or("user");
261 let key_type: RateLimitKey = match key_str.parse() {
262 Ok(k) => k,
263 Err(_) => {
264 tracing::warn!(
265 function = %function_name,
266 key = %key_str,
267 "Invalid rate limit key, falling back to 'user'"
268 );
269 RateLimitKey::default()
270 }
271 };
272
273 let config =
274 RateLimitConfig::new(requests, Duration::from_secs(per_secs)).with_key(key_type);
275
276 let bucket_key = self
278 .rate_limiter
279 .build_key(key_type, function_name, auth, request);
280
281 self.rate_limiter.enforce(&bucket_key, &config).await?;
283
284 Ok(())
285 }
286
287 pub fn get_function_kind(&self, function_name: &str) -> Option<FunctionKind> {
289 self.registry.get(function_name).map(|e| e.kind())
290 }
291
292 pub fn has_function(&self, function_name: &str) -> bool {
294 self.registry.get(function_name).is_some()
295 }
296
297 async fn execute_transactional(
298 &self,
299 handler: &BoxedMutationFn,
300 args: Value,
301 auth: AuthContext,
302 request: RequestMetadata,
303 ) -> Result<RouteResult> {
304 let primary = self.db.primary();
306 let tx = primary
307 .begin()
308 .await
309 .map_err(|e| ForgeError::Database(e.to_string()))?;
310
311 let job_dispatcher = self.job_dispatcher.clone();
312 let job_lookup: forge_core::JobInfoLookup =
313 Arc::new(move |name: &str| job_dispatcher.as_ref().and_then(|d| d.get_info(name)));
314
315 let (ctx, tx_handle, outbox) = MutationContext::with_transaction(
316 primary.clone(),
317 tx,
318 auth,
319 request,
320 self.http_client.clone(),
321 job_lookup,
322 );
323
324 match handler(&ctx, args).await {
325 Ok(value) => {
326 let buffer = {
327 let guard = outbox.lock().unwrap();
328 OutboxBuffer {
329 jobs: guard.jobs.clone(),
330 workflows: guard.workflows.clone(),
331 }
332 };
333
334 let mut tx = Arc::try_unwrap(tx_handle)
335 .map_err(|_| ForgeError::Internal("Transaction still in use".into()))?
336 .into_inner();
337
338 for job in &buffer.jobs {
339 Self::insert_job(&mut tx, job).await?;
340 }
341
342 for workflow in &buffer.workflows {
343 Self::insert_workflow(&mut tx, workflow).await?;
344 }
345
346 tx.commit()
347 .await
348 .map_err(|e| ForgeError::Database(e.to_string()))?;
349
350 Ok(RouteResult::Mutation(value))
351 }
352 Err(e) => Err(e),
353 }
354 }
355
356 async fn insert_job(
357 tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
358 job: &PendingJob,
359 ) -> Result<()> {
360 let now = Utc::now();
361 sqlx::query(
362 r#"
363 INSERT INTO forge_jobs (
364 id, job_type, input, job_context, status, priority, attempts, max_attempts,
365 worker_capability, scheduled_at, created_at
366 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
367 "#,
368 )
369 .bind(job.id)
370 .bind(&job.job_type)
371 .bind(&job.args)
372 .bind(&job.context)
373 .bind(JobStatus::Pending.as_str())
374 .bind(job.priority)
375 .bind(0i32)
376 .bind(job.max_attempts)
377 .bind(&job.worker_capability)
378 .bind(now)
379 .bind(now)
380 .execute(&mut **tx)
381 .await
382 .map_err(|e| ForgeError::Database(e.to_string()))?;
383
384 Ok(())
385 }
386
387 async fn insert_workflow(
388 tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
389 workflow: &PendingWorkflow,
390 ) -> Result<()> {
391 let now = Utc::now();
392 sqlx::query(
393 r#"
394 INSERT INTO forge_workflow_runs (
395 id, workflow_name, input, status, current_step,
396 step_results, started_at, trace_id
397 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
398 "#,
399 )
400 .bind(workflow.id)
401 .bind(&workflow.workflow_name)
402 .bind(&workflow.input)
403 .bind(WorkflowStatus::Created.as_str())
404 .bind(Option::<String>::None)
405 .bind(serde_json::json!({}))
406 .bind(now)
407 .bind(workflow.id.to_string())
408 .execute(&mut **tx)
409 .await
410 .map_err(|e| ForgeError::Database(e.to_string()))?;
411
412 Ok(())
413 }
414}
415
#[cfg(test)]
mod tests {
    use super::*;

    /// Metadata marked public must be admitted for an anonymous caller.
    #[test]
    fn test_check_auth_public() {
        let info = FunctionInfo {
            name: "test",
            description: None,
            kind: FunctionKind::Query,
            required_role: None,
            is_public: true,
            cache_ttl: None,
            timeout: None,
            rate_limit_requests: None,
            rate_limit_per_secs: None,
            rate_limit_key: None,
            log_level: None,
            table_dependencies: &[],
            transactional: false,
        };

        let auth = AuthContext::unauthenticated();

        // Precondition: the caller is anonymous, so only `is_public` can admit it.
        assert!(!auth.is_authenticated());
        assert!(info.is_public);
        // NOTE(review): `check_auth` itself is not invoked here because it takes
        // `&self`, and building a `FunctionRouter` requires a `Database`.
    }
}