// forge_runtime/function/router.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use forge_core::{
5    ActionContext, AuthContext, ForgeError, FunctionInfo, FunctionKind, JobDispatch,
6    MutationContext, QueryContext, RequestMetadata, Result, WorkflowDispatch,
7    rate_limit::{RateLimitConfig, RateLimitKey},
8};
9use serde_json::Value;
10
11use super::cache::QueryCache;
12use super::registry::{FunctionEntry, FunctionRegistry};
13use crate::rate_limit::RateLimiter;
14
15/// Result of routing a function call.
16pub enum RouteResult {
17    /// Query execution result.
18    Query(Value),
19    /// Mutation execution result.
20    Mutation(Value),
21    /// Action execution result.
22    Action(Value),
23}
24
25/// Routes function calls to the appropriate handler.
26pub struct FunctionRouter {
27    registry: Arc<FunctionRegistry>,
28    db_pool: sqlx::PgPool,
29    http_client: reqwest::Client,
30    job_dispatcher: Option<Arc<dyn JobDispatch>>,
31    workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
32    rate_limiter: RateLimiter,
33    query_cache: QueryCache,
34}
35
36impl FunctionRouter {
37    /// Create a new function router.
38    pub fn new(registry: Arc<FunctionRegistry>, db_pool: sqlx::PgPool) -> Self {
39        let rate_limiter = RateLimiter::new(db_pool.clone());
40        Self {
41            registry,
42            db_pool,
43            http_client: reqwest::Client::new(),
44            job_dispatcher: None,
45            workflow_dispatcher: None,
46            rate_limiter,
47            query_cache: QueryCache::new(),
48        }
49    }
50
51    /// Create a new function router with a custom HTTP client.
52    pub fn with_http_client(
53        registry: Arc<FunctionRegistry>,
54        db_pool: sqlx::PgPool,
55        http_client: reqwest::Client,
56    ) -> Self {
57        let rate_limiter = RateLimiter::new(db_pool.clone());
58        Self {
59            registry,
60            db_pool,
61            http_client,
62            job_dispatcher: None,
63            workflow_dispatcher: None,
64            rate_limiter,
65            query_cache: QueryCache::new(),
66        }
67    }
68
69    /// Set the job dispatcher for this router.
70    pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
71        self.job_dispatcher = Some(dispatcher);
72        self
73    }
74
75    /// Set the workflow dispatcher for this router.
76    pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
77        self.workflow_dispatcher = Some(dispatcher);
78        self
79    }
80
81    /// Route and execute a function call.
82    pub async fn route(
83        &self,
84        function_name: &str,
85        args: Value,
86        auth: AuthContext,
87        request: RequestMetadata,
88    ) -> Result<RouteResult> {
89        let entry = self.registry.get(function_name).ok_or_else(|| {
90            ForgeError::NotFound(format!("Function '{}' not found", function_name))
91        })?;
92
93        // Check authorization
94        self.check_auth(entry.info(), &auth)?;
95
96        // Check rate limit
97        self.check_rate_limit(entry.info(), function_name, &auth, &request)
98            .await?;
99
100        match entry {
101            FunctionEntry::Query { handler, info, .. } => {
102                // Check cache first if TTL is configured
103                if let Some(ttl) = info.cache_ttl {
104                    if let Some(cached) = self.query_cache.get(function_name, &args) {
105                        return Ok(RouteResult::Query(cached));
106                    }
107
108                    // Execute and cache result
109                    let ctx = QueryContext::new(self.db_pool.clone(), auth, request);
110                    let result = handler(&ctx, args.clone()).await?;
111
112                    self.query_cache.set(
113                        function_name,
114                        &args,
115                        result.clone(),
116                        Duration::from_secs(ttl),
117                    );
118
119                    Ok(RouteResult::Query(result))
120                } else {
121                    let ctx = QueryContext::new(self.db_pool.clone(), auth, request);
122                    let result = handler(&ctx, args).await?;
123                    Ok(RouteResult::Query(result))
124                }
125            }
126            FunctionEntry::Mutation { handler, .. } => {
127                let ctx = MutationContext::with_dispatch(
128                    self.db_pool.clone(),
129                    auth,
130                    request,
131                    self.job_dispatcher.clone(),
132                    self.workflow_dispatcher.clone(),
133                );
134                let result = handler(&ctx, args).await?;
135                Ok(RouteResult::Mutation(result))
136            }
137            FunctionEntry::Action { handler, .. } => {
138                let ctx = ActionContext::with_dispatch(
139                    self.db_pool.clone(),
140                    auth,
141                    request,
142                    self.http_client.clone(),
143                    self.job_dispatcher.clone(),
144                    self.workflow_dispatcher.clone(),
145                );
146                let result = handler(&ctx, args).await?;
147                Ok(RouteResult::Action(result))
148            }
149        }
150    }
151
152    /// Check authorization for a function call.
153    fn check_auth(&self, info: &FunctionInfo, auth: &AuthContext) -> Result<()> {
154        // Public functions don't require auth
155        if info.is_public {
156            return Ok(());
157        }
158
159        // Check if auth is required
160        if info.requires_auth && !auth.is_authenticated() {
161            return Err(ForgeError::Unauthorized("Authentication required".into()));
162        }
163
164        // Check role requirement
165        if let Some(role) = info.required_role {
166            if !auth.has_role(role) {
167                return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
168            }
169        }
170
171        Ok(())
172    }
173
174    /// Check rate limit for a function call.
175    async fn check_rate_limit(
176        &self,
177        info: &FunctionInfo,
178        function_name: &str,
179        auth: &AuthContext,
180        request: &RequestMetadata,
181    ) -> Result<()> {
182        // Skip if no rate limit configured
183        let (requests, per_secs) = match (info.rate_limit_requests, info.rate_limit_per_secs) {
184            (Some(r), Some(p)) => (r, p),
185            _ => return Ok(()),
186        };
187
188        // Build rate limit config
189        let key_type: RateLimitKey = info
190            .rate_limit_key
191            .unwrap_or("user")
192            .parse()
193            .unwrap_or_default();
194
195        let config =
196            RateLimitConfig::new(requests, Duration::from_secs(per_secs)).with_key(key_type);
197
198        // Build bucket key
199        let bucket_key = self
200            .rate_limiter
201            .build_key(key_type, function_name, auth, request);
202
203        // Enforce rate limit
204        self.rate_limiter.enforce(&bucket_key, &config).await?;
205
206        Ok(())
207    }
208
209    /// Get the function kind by name.
210    pub fn get_function_kind(&self, function_name: &str) -> Option<FunctionKind> {
211        self.registry.get(function_name).map(|e| e.kind())
212    }
213
214    /// Check if a function exists.
215    pub fn has_function(&self, function_name: &str) -> bool {
216        self.registry.get(function_name).is_some()
217    }
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_check_auth_public() {
        // Fixture describing a public, uncached, unthrottled query.
        let info = FunctionInfo {
            name: "test",
            description: None,
            kind: FunctionKind::Query,
            is_public: true,
            requires_auth: false,
            required_role: None,
            cache_ttl: None,
            timeout: None,
            rate_limit_key: None,
            rate_limit_requests: None,
            rate_limit_per_secs: None,
        };

        // Anonymous caller; kept for when `check_auth` can be exercised directly.
        let _auth = AuthContext::unauthenticated();

        // `check_auth` is a method, and building a `FunctionRouter` needs a
        // database pool, so this only verifies the fixture's public flag.
        assert!(info.is_public, "fixture should describe a public function");
    }
}