Skip to main content

forge_runtime/function/
router.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use chrono::Utc;
5use forge_core::{
6    AuthContext, CircuitBreakerClient, ForgeError, FunctionInfo, FunctionKind, JobDispatch,
7    MutationContext, OutboxBuffer, PendingJob, PendingWorkflow, QueryContext, RequestMetadata,
8    Result, WorkflowDispatch,
9    job::JobStatus,
10    rate_limit::{RateLimitConfig, RateLimitKey},
11    workflow::WorkflowStatus,
12};
13use serde_json::Value;
14
15use super::cache::QueryCache;
16use super::registry::{BoxedMutationFn, FunctionEntry, FunctionRegistry};
17use crate::db::Database;
18use crate::rate_limit::RateLimiter;
19
20/// Result of routing a function call.
21pub enum RouteResult {
22    /// Query execution result.
23    Query(Value),
24    /// Mutation execution result.
25    Mutation(Value),
26    /// Job dispatch result (returns job_id).
27    Job(Value),
28    /// Workflow dispatch result (returns workflow_id).
29    Workflow(Value),
30}
31
32/// Routes function calls to the appropriate handler.
33pub struct FunctionRouter {
34    registry: Arc<FunctionRegistry>,
35    db: Database,
36    http_client: CircuitBreakerClient,
37    job_dispatcher: Option<Arc<dyn JobDispatch>>,
38    workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
39    rate_limiter: RateLimiter,
40    query_cache: QueryCache,
41}
42
43impl FunctionRouter {
44    /// Create a new function router.
45    pub fn new(registry: Arc<FunctionRegistry>, db: Database) -> Self {
46        let rate_limiter = RateLimiter::new(db.primary().clone());
47        Self {
48            registry,
49            db,
50            http_client: CircuitBreakerClient::with_defaults(reqwest::Client::new()),
51            job_dispatcher: None,
52            workflow_dispatcher: None,
53            rate_limiter,
54            query_cache: QueryCache::new(),
55        }
56    }
57
58    /// Create a new function router with a custom HTTP client.
59    pub fn with_http_client(
60        registry: Arc<FunctionRegistry>,
61        db: Database,
62        http_client: CircuitBreakerClient,
63    ) -> Self {
64        let rate_limiter = RateLimiter::new(db.primary().clone());
65        Self {
66            registry,
67            db,
68            http_client,
69            job_dispatcher: None,
70            workflow_dispatcher: None,
71            rate_limiter,
72            query_cache: QueryCache::new(),
73        }
74    }
75
76    /// Set the job dispatcher for this router.
77    pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
78        self.job_dispatcher = Some(dispatcher);
79        self
80    }
81
82    /// Set the workflow dispatcher for this router.
83    pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
84        self.workflow_dispatcher = Some(dispatcher);
85        self
86    }
87
88    pub async fn route(
89        &self,
90        function_name: &str,
91        args: Value,
92        auth: AuthContext,
93        request: RequestMetadata,
94    ) -> Result<RouteResult> {
95        if let Some(entry) = self.registry.get(function_name) {
96            self.check_auth(entry.info(), &auth)?;
97            self.check_rate_limit(entry.info(), function_name, &auth, &request)
98                .await?;
99
100            return match entry {
101                FunctionEntry::Query { handler, info, .. } => {
102                    if let Some(ttl) = info.cache_ttl {
103                        if let Some(cached) = self.query_cache.get(function_name, &args) {
104                            return Ok(RouteResult::Query(cached));
105                        }
106
107                        // Execute and cache result (use read replica for queries)
108                        let ctx = QueryContext::new(self.db.read_pool().clone(), auth, request);
109                        let result = handler(&ctx, args.clone()).await?;
110
111                        self.query_cache.set(
112                            function_name,
113                            &args,
114                            result.clone(),
115                            Duration::from_secs(ttl),
116                        );
117
118                        Ok(RouteResult::Query(result))
119                    } else {
120                        // Use read replica for queries
121                        let ctx = QueryContext::new(self.db.read_pool().clone(), auth, request);
122                        let result = handler(&ctx, args).await?;
123                        Ok(RouteResult::Query(result))
124                    }
125                }
126                FunctionEntry::Mutation { handler, info } => {
127                    if info.transactional {
128                        self.execute_transactional(handler, args, auth, request)
129                            .await
130                    } else {
131                        // Use primary for mutations
132                        let ctx = MutationContext::with_dispatch(
133                            self.db.primary().clone(),
134                            auth,
135                            request,
136                            self.http_client.clone(),
137                            self.job_dispatcher.clone(),
138                            self.workflow_dispatcher.clone(),
139                        );
140                        let result = handler(&ctx, args).await?;
141                        Ok(RouteResult::Mutation(result))
142                    }
143                }
144            };
145        }
146
147        if let Some(ref job_dispatcher) = self.job_dispatcher {
148            if let Some(job_info) = job_dispatcher.get_info(function_name) {
149                self.check_job_auth(&job_info, &auth)?;
150                match job_dispatcher
151                    .dispatch_by_name(function_name, args.clone())
152                    .await
153                {
154                    Ok(job_id) => {
155                        return Ok(RouteResult::Job(serde_json::json!({ "job_id": job_id })));
156                    }
157                    Err(ForgeError::NotFound(_)) => {}
158                    Err(e) => return Err(e),
159                }
160            }
161        }
162
163        if let Some(ref workflow_dispatcher) = self.workflow_dispatcher {
164            if let Some(workflow_info) = workflow_dispatcher.get_info(function_name) {
165                self.check_workflow_auth(&workflow_info, &auth)?;
166                match workflow_dispatcher
167                    .start_by_name(function_name, args.clone())
168                    .await
169                {
170                    Ok(workflow_id) => {
171                        return Ok(RouteResult::Workflow(
172                            serde_json::json!({ "workflow_id": workflow_id }),
173                        ));
174                    }
175                    Err(ForgeError::NotFound(_)) => {}
176                    Err(e) => return Err(e),
177                }
178            }
179        }
180
181        Err(ForgeError::NotFound(format!(
182            "Function '{}' not found",
183            function_name
184        )))
185    }
186
187    fn check_auth(&self, info: &FunctionInfo, auth: &AuthContext) -> Result<()> {
188        if info.is_public {
189            return Ok(());
190        }
191
192        if !auth.is_authenticated() {
193            return Err(ForgeError::Unauthorized("Authentication required".into()));
194        }
195
196        if let Some(role) = info.required_role {
197            if !auth.has_role(role) {
198                return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
199            }
200        }
201
202        Ok(())
203    }
204
205    fn check_job_auth(&self, info: &forge_core::job::JobInfo, auth: &AuthContext) -> Result<()> {
206        if info.is_public {
207            return Ok(());
208        }
209
210        if !auth.is_authenticated() {
211            return Err(ForgeError::Unauthorized("Authentication required".into()));
212        }
213
214        if let Some(role) = info.required_role {
215            if !auth.has_role(role) {
216                return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
217            }
218        }
219
220        Ok(())
221    }
222
223    fn check_workflow_auth(
224        &self,
225        info: &forge_core::workflow::WorkflowInfo,
226        auth: &AuthContext,
227    ) -> Result<()> {
228        if info.is_public {
229            return Ok(());
230        }
231
232        if !auth.is_authenticated() {
233            return Err(ForgeError::Unauthorized("Authentication required".into()));
234        }
235
236        if let Some(role) = info.required_role {
237            if !auth.has_role(role) {
238                return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
239            }
240        }
241
242        Ok(())
243    }
244
245    /// Check rate limit for a function call.
246    async fn check_rate_limit(
247        &self,
248        info: &FunctionInfo,
249        function_name: &str,
250        auth: &AuthContext,
251        request: &RequestMetadata,
252    ) -> Result<()> {
253        // Skip if no rate limit configured
254        let (requests, per_secs) = match (info.rate_limit_requests, info.rate_limit_per_secs) {
255            (Some(r), Some(p)) => (r, p),
256            _ => return Ok(()),
257        };
258
259        // Build rate limit config
260        let key_str = info.rate_limit_key.unwrap_or("user");
261        let key_type: RateLimitKey = match key_str.parse() {
262            Ok(k) => k,
263            Err(_) => {
264                tracing::warn!(
265                    function = %function_name,
266                    key = %key_str,
267                    "Invalid rate limit key, falling back to 'user'"
268                );
269                RateLimitKey::default()
270            }
271        };
272
273        let config =
274            RateLimitConfig::new(requests, Duration::from_secs(per_secs)).with_key(key_type);
275
276        // Build bucket key
277        let bucket_key = self
278            .rate_limiter
279            .build_key(key_type, function_name, auth, request);
280
281        // Enforce rate limit
282        self.rate_limiter.enforce(&bucket_key, &config).await?;
283
284        Ok(())
285    }
286
287    /// Get the function kind by name.
288    pub fn get_function_kind(&self, function_name: &str) -> Option<FunctionKind> {
289        self.registry.get(function_name).map(|e| e.kind())
290    }
291
292    /// Check if a function exists.
293    pub fn has_function(&self, function_name: &str) -> bool {
294        self.registry.get(function_name).is_some()
295    }
296
297    async fn execute_transactional(
298        &self,
299        handler: &BoxedMutationFn,
300        args: Value,
301        auth: AuthContext,
302        request: RequestMetadata,
303    ) -> Result<RouteResult> {
304        // Use primary for transactional mutations
305        let primary = self.db.primary();
306        let tx = primary
307            .begin()
308            .await
309            .map_err(|e| ForgeError::Database(e.to_string()))?;
310
311        let job_dispatcher = self.job_dispatcher.clone();
312        let job_lookup: forge_core::JobInfoLookup =
313            Arc::new(move |name: &str| job_dispatcher.as_ref().and_then(|d| d.get_info(name)));
314
315        let (ctx, tx_handle, outbox) = MutationContext::with_transaction(
316            primary.clone(),
317            tx,
318            auth,
319            request,
320            self.http_client.clone(),
321            job_lookup,
322        );
323
324        match handler(&ctx, args).await {
325            Ok(value) => {
326                let buffer = {
327                    let guard = outbox.lock().unwrap();
328                    OutboxBuffer {
329                        jobs: guard.jobs.clone(),
330                        workflows: guard.workflows.clone(),
331                    }
332                };
333
334                let mut tx = Arc::try_unwrap(tx_handle)
335                    .map_err(|_| ForgeError::Internal("Transaction still in use".into()))?
336                    .into_inner();
337
338                for job in &buffer.jobs {
339                    Self::insert_job(&mut tx, job).await?;
340                }
341
342                for workflow in &buffer.workflows {
343                    Self::insert_workflow(&mut tx, workflow).await?;
344                }
345
346                tx.commit()
347                    .await
348                    .map_err(|e| ForgeError::Database(e.to_string()))?;
349
350                Ok(RouteResult::Mutation(value))
351            }
352            Err(e) => Err(e),
353        }
354    }
355
356    async fn insert_job(
357        tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
358        job: &PendingJob,
359    ) -> Result<()> {
360        let now = Utc::now();
361        sqlx::query(
362            r#"
363            INSERT INTO forge_jobs (
364                id, job_type, input, job_context, status, priority, attempts, max_attempts,
365                worker_capability, scheduled_at, created_at
366            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
367            "#,
368        )
369        .bind(job.id)
370        .bind(&job.job_type)
371        .bind(&job.args)
372        .bind(&job.context)
373        .bind(JobStatus::Pending.as_str())
374        .bind(job.priority)
375        .bind(0i32)
376        .bind(job.max_attempts)
377        .bind(&job.worker_capability)
378        .bind(now)
379        .bind(now)
380        .execute(&mut **tx)
381        .await
382        .map_err(|e| ForgeError::Database(e.to_string()))?;
383
384        Ok(())
385    }
386
387    async fn insert_workflow(
388        tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
389        workflow: &PendingWorkflow,
390    ) -> Result<()> {
391        let now = Utc::now();
392        sqlx::query(
393            r#"
394            INSERT INTO forge_workflow_runs (
395                id, workflow_name, input, status, current_step,
396                step_results, started_at, trace_id
397            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
398            "#,
399        )
400        .bind(workflow.id)
401        .bind(&workflow.workflow_name)
402        .bind(&workflow.input)
403        .bind(WorkflowStatus::Created.as_str())
404        .bind(Option::<String>::None)
405        .bind(serde_json::json!({}))
406        .bind(now)
407        .bind(workflow.id.to_string())
408        .execute(&mut **tx)
409        .await
410        .map_err(|e| ForgeError::Database(e.to_string()))?;
411
412        Ok(())
413    }
414}
415
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_check_auth_public() {
        // `check_auth` is a router method, and building a router requires a live
        // `Database`, so this test only pins the flag that makes `check_auth`
        // return early.
        let info = FunctionInfo {
            name: "test",
            description: None,
            kind: FunctionKind::Query,
            is_public: true,
            required_role: None,
            transactional: false,
            timeout: None,
            cache_ttl: None,
            log_level: None,
            rate_limit_key: None,
            rate_limit_requests: None,
            rate_limit_per_secs: None,
            table_dependencies: &[],
        };
        let _auth = AuthContext::unauthenticated();

        assert!(info.is_public, "fixture must describe a public function");
    }
}