forge_runtime/function/router.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use forge_core::{
5    AuthContext, ForgeError, FunctionInfo, FunctionKind, JobDispatch, MutationContext,
6    QueryContext, RequestMetadata, Result, WorkflowDispatch,
7    rate_limit::{RateLimitConfig, RateLimitKey},
8};
9use serde_json::Value;
10
11use super::cache::QueryCache;
12use super::registry::{FunctionEntry, FunctionRegistry};
13use crate::rate_limit::RateLimiter;
14
15/// Result of routing a function call.
16pub enum RouteResult {
17    /// Query execution result.
18    Query(Value),
19    /// Mutation execution result.
20    Mutation(Value),
21}
22
23/// Routes function calls to the appropriate handler.
24pub struct FunctionRouter {
25    registry: Arc<FunctionRegistry>,
26    db_pool: sqlx::PgPool,
27    http_client: reqwest::Client,
28    job_dispatcher: Option<Arc<dyn JobDispatch>>,
29    workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
30    rate_limiter: RateLimiter,
31    query_cache: QueryCache,
32}
33
34impl FunctionRouter {
35    /// Create a new function router.
36    pub fn new(registry: Arc<FunctionRegistry>, db_pool: sqlx::PgPool) -> Self {
37        let rate_limiter = RateLimiter::new(db_pool.clone());
38        Self {
39            registry,
40            db_pool,
41            http_client: reqwest::Client::new(),
42            job_dispatcher: None,
43            workflow_dispatcher: None,
44            rate_limiter,
45            query_cache: QueryCache::new(),
46        }
47    }
48
49    /// Create a new function router with a custom HTTP client.
50    pub fn with_http_client(
51        registry: Arc<FunctionRegistry>,
52        db_pool: sqlx::PgPool,
53        http_client: reqwest::Client,
54    ) -> Self {
55        let rate_limiter = RateLimiter::new(db_pool.clone());
56        Self {
57            registry,
58            db_pool,
59            http_client,
60            job_dispatcher: None,
61            workflow_dispatcher: None,
62            rate_limiter,
63            query_cache: QueryCache::new(),
64        }
65    }
66
67    /// Set the job dispatcher for this router.
68    pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
69        self.job_dispatcher = Some(dispatcher);
70        self
71    }
72
73    /// Set the workflow dispatcher for this router.
74    pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
75        self.workflow_dispatcher = Some(dispatcher);
76        self
77    }
78
79    /// Route and execute a function call.
80    pub async fn route(
81        &self,
82        function_name: &str,
83        args: Value,
84        auth: AuthContext,
85        request: RequestMetadata,
86    ) -> Result<RouteResult> {
87        let entry = self.registry.get(function_name).ok_or_else(|| {
88            ForgeError::NotFound(format!("Function '{}' not found", function_name))
89        })?;
90
91        // Check authorization
92        self.check_auth(entry.info(), &auth)?;
93
94        // Check rate limit
95        self.check_rate_limit(entry.info(), function_name, &auth, &request)
96            .await?;
97
98        match entry {
99            FunctionEntry::Query { handler, info, .. } => {
100                // Check cache first if TTL is configured
101                if let Some(ttl) = info.cache_ttl {
102                    if let Some(cached) = self.query_cache.get(function_name, &args) {
103                        return Ok(RouteResult::Query(cached));
104                    }
105
106                    // Execute and cache result
107                    let ctx = QueryContext::new(self.db_pool.clone(), auth, request);
108                    let result = handler(&ctx, args.clone()).await?;
109
110                    self.query_cache.set(
111                        function_name,
112                        &args,
113                        result.clone(),
114                        Duration::from_secs(ttl),
115                    );
116
117                    Ok(RouteResult::Query(result))
118                } else {
119                    let ctx = QueryContext::new(self.db_pool.clone(), auth, request);
120                    let result = handler(&ctx, args).await?;
121                    Ok(RouteResult::Query(result))
122                }
123            }
124            FunctionEntry::Mutation { handler, .. } => {
125                let ctx = MutationContext::with_dispatch(
126                    self.db_pool.clone(),
127                    auth,
128                    request,
129                    self.http_client.clone(),
130                    self.job_dispatcher.clone(),
131                    self.workflow_dispatcher.clone(),
132                );
133                let result = handler(&ctx, args).await?;
134                Ok(RouteResult::Mutation(result))
135            }
136        }
137    }
138
139    /// Check authorization for a function call.
140    fn check_auth(&self, info: &FunctionInfo, auth: &AuthContext) -> Result<()> {
141        // Public functions don't require auth
142        if info.is_public {
143            return Ok(());
144        }
145
146        // Check if auth is required
147        if info.requires_auth && !auth.is_authenticated() {
148            return Err(ForgeError::Unauthorized("Authentication required".into()));
149        }
150
151        // Check role requirement
152        if let Some(role) = info.required_role {
153            // Role check implies authentication is required
154            if !auth.is_authenticated() {
155                return Err(ForgeError::Unauthorized("Authentication required".into()));
156            }
157            if !auth.has_role(role) {
158                return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
159            }
160        }
161
162        Ok(())
163    }
164
165    /// Check rate limit for a function call.
166    async fn check_rate_limit(
167        &self,
168        info: &FunctionInfo,
169        function_name: &str,
170        auth: &AuthContext,
171        request: &RequestMetadata,
172    ) -> Result<()> {
173        // Skip if no rate limit configured
174        let (requests, per_secs) = match (info.rate_limit_requests, info.rate_limit_per_secs) {
175            (Some(r), Some(p)) => (r, p),
176            _ => return Ok(()),
177        };
178
179        // Build rate limit config
180        let key_type: RateLimitKey = info
181            .rate_limit_key
182            .unwrap_or("user")
183            .parse()
184            .unwrap_or_default();
185
186        let config =
187            RateLimitConfig::new(requests, Duration::from_secs(per_secs)).with_key(key_type);
188
189        // Build bucket key
190        let bucket_key = self
191            .rate_limiter
192            .build_key(key_type, function_name, auth, request);
193
194        // Enforce rate limit
195        self.rate_limiter.enforce(&bucket_key, &config).await?;
196
197        Ok(())
198    }
199
200    /// Get the function kind by name.
201    pub fn get_function_kind(&self, function_name: &str) -> Option<FunctionKind> {
202        self.registry.get(function_name).map(|e| e.kind())
203    }
204
205    /// Check if a function exists.
206    pub fn has_function(&self, function_name: &str) -> bool {
207        self.registry.get(function_name).is_some()
208    }
209}
210
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_check_auth_public() {
        // Metadata for a public query with no caching, timeout or rate limit.
        let info = FunctionInfo {
            name: "test",
            description: None,
            kind: FunctionKind::Query,
            is_public: true,
            requires_auth: false,
            required_role: None,
            cache_ttl: None,
            timeout: None,
            log_level: None,
            rate_limit_requests: None,
            rate_limit_per_secs: None,
            rate_limit_key: None,
            table_dependencies: &[],
            transactional: false,
        };

        // An anonymous caller. `check_auth` itself needs a router instance (and so
        // a database pool), so this test only asserts the flag that `check_auth`
        // short-circuits on.
        let _anonymous = AuthContext::unauthenticated();

        assert!(info.is_public);
    }
}