forge_runtime/function/router.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use forge_core::{
5    AuthContext, ForgeError, FunctionInfo, FunctionKind, JobDispatch, MutationContext,
6    QueryContext, RequestMetadata, Result, WorkflowDispatch,
7    rate_limit::{RateLimitConfig, RateLimitKey},
8};
9use serde_json::Value;
10
11use super::cache::QueryCache;
12use super::registry::{FunctionEntry, FunctionRegistry};
13use crate::rate_limit::RateLimiter;
14
15/// Result of routing a function call.
16pub enum RouteResult {
17    /// Query execution result.
18    Query(Value),
19    /// Mutation execution result.
20    Mutation(Value),
21    /// Job dispatch result (returns job_id).
22    Job(Value),
23    /// Workflow dispatch result (returns workflow_id).
24    Workflow(Value),
25}
26
27/// Routes function calls to the appropriate handler.
28pub struct FunctionRouter {
29    registry: Arc<FunctionRegistry>,
30    db_pool: sqlx::PgPool,
31    http_client: reqwest::Client,
32    job_dispatcher: Option<Arc<dyn JobDispatch>>,
33    workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
34    rate_limiter: RateLimiter,
35    query_cache: QueryCache,
36}
37
38impl FunctionRouter {
39    /// Create a new function router.
40    pub fn new(registry: Arc<FunctionRegistry>, db_pool: sqlx::PgPool) -> Self {
41        let rate_limiter = RateLimiter::new(db_pool.clone());
42        Self {
43            registry,
44            db_pool,
45            http_client: reqwest::Client::new(),
46            job_dispatcher: None,
47            workflow_dispatcher: None,
48            rate_limiter,
49            query_cache: QueryCache::new(),
50        }
51    }
52
53    /// Create a new function router with a custom HTTP client.
54    pub fn with_http_client(
55        registry: Arc<FunctionRegistry>,
56        db_pool: sqlx::PgPool,
57        http_client: reqwest::Client,
58    ) -> Self {
59        let rate_limiter = RateLimiter::new(db_pool.clone());
60        Self {
61            registry,
62            db_pool,
63            http_client,
64            job_dispatcher: None,
65            workflow_dispatcher: None,
66            rate_limiter,
67            query_cache: QueryCache::new(),
68        }
69    }
70
71    /// Set the job dispatcher for this router.
72    pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
73        self.job_dispatcher = Some(dispatcher);
74        self
75    }
76
77    /// Set the workflow dispatcher for this router.
78    pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
79        self.workflow_dispatcher = Some(dispatcher);
80        self
81    }
82
83    /// Route and execute a function call.
84    pub async fn route(
85        &self,
86        function_name: &str,
87        args: Value,
88        auth: AuthContext,
89        request: RequestMetadata,
90    ) -> Result<RouteResult> {
91        // First, try to find in the function registry (queries/mutations)
92        if let Some(entry) = self.registry.get(function_name) {
93            // Check authorization
94            self.check_auth(entry.info(), &auth)?;
95
96            // Check rate limit
97            self.check_rate_limit(entry.info(), function_name, &auth, &request)
98                .await?;
99
100            return match entry {
101                FunctionEntry::Query { handler, info, .. } => {
102                    if let Some(ttl) = info.cache_ttl {
103                        if let Some(cached) = self.query_cache.get(function_name, &args) {
104                            return Ok(RouteResult::Query(cached));
105                        }
106
107                        // Execute and cache result
108                        let ctx = QueryContext::new(self.db_pool.clone(), auth, request);
109                        let result = handler(&ctx, args.clone()).await?;
110
111                        self.query_cache.set(
112                            function_name,
113                            &args,
114                            result.clone(),
115                            Duration::from_secs(ttl),
116                        );
117
118                        Ok(RouteResult::Query(result))
119                    } else {
120                        let ctx = QueryContext::new(self.db_pool.clone(), auth, request);
121                        let result = handler(&ctx, args).await?;
122                        Ok(RouteResult::Query(result))
123                    }
124                }
125                FunctionEntry::Mutation { handler, .. } => {
126                    let ctx = MutationContext::with_dispatch(
127                        self.db_pool.clone(),
128                        auth,
129                        request,
130                        self.http_client.clone(),
131                        self.job_dispatcher.clone(),
132                        self.workflow_dispatcher.clone(),
133                    );
134                    let result = handler(&ctx, args).await?;
135                    Ok(RouteResult::Mutation(result))
136                }
137            };
138        }
139
140        // Try job dispatcher - check auth using job info
141        if let Some(ref job_dispatcher) = self.job_dispatcher {
142            if let Some(job_info) = job_dispatcher.get_info(function_name) {
143                self.check_job_auth(&job_info, &auth)?;
144                match job_dispatcher
145                    .dispatch_by_name(function_name, args.clone())
146                    .await
147                {
148                    Ok(job_id) => {
149                        return Ok(RouteResult::Job(serde_json::json!({ "job_id": job_id })));
150                    }
151                    Err(ForgeError::NotFound(_)) => {}
152                    Err(e) => return Err(e),
153                }
154            }
155        }
156
157        // Try workflow dispatcher - check auth using workflow info
158        if let Some(ref workflow_dispatcher) = self.workflow_dispatcher {
159            if let Some(workflow_info) = workflow_dispatcher.get_info(function_name) {
160                self.check_workflow_auth(&workflow_info, &auth)?;
161                match workflow_dispatcher
162                    .start_by_name(function_name, args.clone())
163                    .await
164                {
165                    Ok(workflow_id) => {
166                        return Ok(RouteResult::Workflow(
167                            serde_json::json!({ "workflow_id": workflow_id }),
168                        ));
169                    }
170                    Err(ForgeError::NotFound(_)) => {}
171                    Err(e) => return Err(e),
172                }
173            }
174        }
175
176        // Nothing found
177        Err(ForgeError::NotFound(format!(
178            "Function '{}' not found",
179            function_name
180        )))
181    }
182
183    /// Check authorization for a function call.
184    fn check_auth(&self, info: &FunctionInfo, auth: &AuthContext) -> Result<()> {
185        // Public functions don't require auth
186        if info.is_public {
187            return Ok(());
188        }
189
190        // Check role requirement (implies auth required)
191        if let Some(role) = info.required_role {
192            if !auth.is_authenticated() {
193                return Err(ForgeError::Unauthorized("Authentication required".into()));
194            }
195            if !auth.has_role(role) {
196                return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
197            }
198        }
199
200        Ok(())
201    }
202
203    /// Check authorization for a job dispatch.
204    fn check_job_auth(&self, info: &forge_core::job::JobInfo, auth: &AuthContext) -> Result<()> {
205        if info.is_public {
206            return Ok(());
207        }
208
209        if let Some(role) = info.required_role {
210            if !auth.is_authenticated() {
211                return Err(ForgeError::Unauthorized("Authentication required".into()));
212            }
213            if !auth.has_role(role) {
214                return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
215            }
216        }
217
218        Ok(())
219    }
220
221    /// Check authorization for a workflow dispatch.
222    fn check_workflow_auth(
223        &self,
224        info: &forge_core::workflow::WorkflowInfo,
225        auth: &AuthContext,
226    ) -> Result<()> {
227        if info.is_public {
228            return Ok(());
229        }
230
231        if let Some(role) = info.required_role {
232            if !auth.is_authenticated() {
233                return Err(ForgeError::Unauthorized("Authentication required".into()));
234            }
235            if !auth.has_role(role) {
236                return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
237            }
238        }
239
240        Ok(())
241    }
242
243    /// Check rate limit for a function call.
244    async fn check_rate_limit(
245        &self,
246        info: &FunctionInfo,
247        function_name: &str,
248        auth: &AuthContext,
249        request: &RequestMetadata,
250    ) -> Result<()> {
251        // Skip if no rate limit configured
252        let (requests, per_secs) = match (info.rate_limit_requests, info.rate_limit_per_secs) {
253            (Some(r), Some(p)) => (r, p),
254            _ => return Ok(()),
255        };
256
257        // Build rate limit config
258        let key_str = info.rate_limit_key.unwrap_or("user");
259        let key_type: RateLimitKey = match key_str.parse() {
260            Ok(k) => k,
261            Err(_) => {
262                tracing::warn!(
263                    function = %function_name,
264                    key = %key_str,
265                    "Invalid rate limit key, falling back to 'user'"
266                );
267                RateLimitKey::default()
268            }
269        };
270
271        let config =
272            RateLimitConfig::new(requests, Duration::from_secs(per_secs)).with_key(key_type);
273
274        // Build bucket key
275        let bucket_key = self
276            .rate_limiter
277            .build_key(key_type, function_name, auth, request);
278
279        // Enforce rate limit
280        self.rate_limiter.enforce(&bucket_key, &config).await?;
281
282        Ok(())
283    }
284
285    /// Get the function kind by name.
286    pub fn get_function_kind(&self, function_name: &str) -> Option<FunctionKind> {
287        self.registry.get(function_name).map(|e| e.kind())
288    }
289
290    /// Check if a function exists.
291    pub fn has_function(&self, function_name: &str) -> bool {
292        self.registry.get(function_name).is_some()
293    }
294}
295
#[cfg(test)]
mod tests {
    use super::*;

    /// A function described with `is_public: true` carries that flag, which is
    /// what `check_auth` short-circuits on.
    #[test]
    fn test_check_auth_public() {
        let public_query = FunctionInfo {
            name: "test",
            kind: FunctionKind::Query,
            description: None,
            is_public: true,
            required_role: None,
            cache_ttl: None,
            timeout: None,
            log_level: None,
            rate_limit_requests: None,
            rate_limit_per_secs: None,
            rate_limit_key: None,
            table_dependencies: &[],
            transactional: false,
        };

        // The caller that would be handed to `check_auth`. Calling `check_auth`
        // itself needs a `FunctionRouter`, which in turn needs an `sqlx::PgPool`,
        // so only the flag it relies on is checked here.
        let _anonymous = AuthContext::unauthenticated();

        assert!(public_query.is_public);
    }
}