1use std::sync::Arc;
2use std::time::Duration;
3
4use forge_core::{
5 AuthContext, ForgeError, FunctionInfo, FunctionKind, JobDispatch, MutationContext,
6 QueryContext, RequestMetadata, Result, WorkflowDispatch,
7 rate_limit::{RateLimitConfig, RateLimitKey},
8};
9use serde_json::Value;
10
11use super::cache::QueryCache;
12use super::registry::{FunctionEntry, FunctionRegistry};
13use crate::rate_limit::RateLimiter;
14
15pub enum RouteResult {
17 Query(Value),
19 Mutation(Value),
21 Job(Value),
23 Workflow(Value),
25}
26
27pub struct FunctionRouter {
29 registry: Arc<FunctionRegistry>,
30 db_pool: sqlx::PgPool,
31 http_client: reqwest::Client,
32 job_dispatcher: Option<Arc<dyn JobDispatch>>,
33 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
34 rate_limiter: RateLimiter,
35 query_cache: QueryCache,
36}
37
38impl FunctionRouter {
39 pub fn new(registry: Arc<FunctionRegistry>, db_pool: sqlx::PgPool) -> Self {
41 let rate_limiter = RateLimiter::new(db_pool.clone());
42 Self {
43 registry,
44 db_pool,
45 http_client: reqwest::Client::new(),
46 job_dispatcher: None,
47 workflow_dispatcher: None,
48 rate_limiter,
49 query_cache: QueryCache::new(),
50 }
51 }
52
53 pub fn with_http_client(
55 registry: Arc<FunctionRegistry>,
56 db_pool: sqlx::PgPool,
57 http_client: reqwest::Client,
58 ) -> Self {
59 let rate_limiter = RateLimiter::new(db_pool.clone());
60 Self {
61 registry,
62 db_pool,
63 http_client,
64 job_dispatcher: None,
65 workflow_dispatcher: None,
66 rate_limiter,
67 query_cache: QueryCache::new(),
68 }
69 }
70
71 pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
73 self.job_dispatcher = Some(dispatcher);
74 self
75 }
76
77 pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
79 self.workflow_dispatcher = Some(dispatcher);
80 self
81 }
82
83 pub async fn route(
85 &self,
86 function_name: &str,
87 args: Value,
88 auth: AuthContext,
89 request: RequestMetadata,
90 ) -> Result<RouteResult> {
91 if let Some(entry) = self.registry.get(function_name) {
93 self.check_auth(entry.info(), &auth)?;
95
96 self.check_rate_limit(entry.info(), function_name, &auth, &request)
98 .await?;
99
100 return match entry {
101 FunctionEntry::Query { handler, info, .. } => {
102 if let Some(ttl) = info.cache_ttl {
103 if let Some(cached) = self.query_cache.get(function_name, &args) {
104 return Ok(RouteResult::Query(cached));
105 }
106
107 let ctx = QueryContext::new(self.db_pool.clone(), auth, request);
109 let result = handler(&ctx, args.clone()).await?;
110
111 self.query_cache.set(
112 function_name,
113 &args,
114 result.clone(),
115 Duration::from_secs(ttl),
116 );
117
118 Ok(RouteResult::Query(result))
119 } else {
120 let ctx = QueryContext::new(self.db_pool.clone(), auth, request);
121 let result = handler(&ctx, args).await?;
122 Ok(RouteResult::Query(result))
123 }
124 }
125 FunctionEntry::Mutation { handler, .. } => {
126 let ctx = MutationContext::with_dispatch(
127 self.db_pool.clone(),
128 auth,
129 request,
130 self.http_client.clone(),
131 self.job_dispatcher.clone(),
132 self.workflow_dispatcher.clone(),
133 );
134 let result = handler(&ctx, args).await?;
135 Ok(RouteResult::Mutation(result))
136 }
137 };
138 }
139
140 if let Some(ref job_dispatcher) = self.job_dispatcher {
142 if let Some(job_info) = job_dispatcher.get_info(function_name) {
143 self.check_job_auth(&job_info, &auth)?;
144 match job_dispatcher
145 .dispatch_by_name(function_name, args.clone())
146 .await
147 {
148 Ok(job_id) => {
149 return Ok(RouteResult::Job(serde_json::json!({ "job_id": job_id })));
150 }
151 Err(ForgeError::NotFound(_)) => {}
152 Err(e) => return Err(e),
153 }
154 }
155 }
156
157 if let Some(ref workflow_dispatcher) = self.workflow_dispatcher {
159 if let Some(workflow_info) = workflow_dispatcher.get_info(function_name) {
160 self.check_workflow_auth(&workflow_info, &auth)?;
161 match workflow_dispatcher
162 .start_by_name(function_name, args.clone())
163 .await
164 {
165 Ok(workflow_id) => {
166 return Ok(RouteResult::Workflow(
167 serde_json::json!({ "workflow_id": workflow_id }),
168 ));
169 }
170 Err(ForgeError::NotFound(_)) => {}
171 Err(e) => return Err(e),
172 }
173 }
174 }
175
176 Err(ForgeError::NotFound(format!(
178 "Function '{}' not found",
179 function_name
180 )))
181 }
182
183 fn check_auth(&self, info: &FunctionInfo, auth: &AuthContext) -> Result<()> {
185 if info.is_public {
186 return Ok(());
187 }
188
189 if !auth.is_authenticated() {
191 return Err(ForgeError::Unauthorized("Authentication required".into()));
192 }
193
194 if let Some(role) = info.required_role {
195 if !auth.has_role(role) {
196 return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
197 }
198 }
199
200 Ok(())
201 }
202
203 fn check_job_auth(&self, info: &forge_core::job::JobInfo, auth: &AuthContext) -> Result<()> {
205 if info.is_public {
206 return Ok(());
207 }
208
209 if !auth.is_authenticated() {
211 return Err(ForgeError::Unauthorized("Authentication required".into()));
212 }
213
214 if let Some(role) = info.required_role {
215 if !auth.has_role(role) {
216 return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
217 }
218 }
219
220 Ok(())
221 }
222
223 fn check_workflow_auth(
225 &self,
226 info: &forge_core::workflow::WorkflowInfo,
227 auth: &AuthContext,
228 ) -> Result<()> {
229 if info.is_public {
230 return Ok(());
231 }
232
233 if !auth.is_authenticated() {
235 return Err(ForgeError::Unauthorized("Authentication required".into()));
236 }
237
238 if let Some(role) = info.required_role {
239 if !auth.has_role(role) {
240 return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
241 }
242 }
243
244 Ok(())
245 }
246
247 async fn check_rate_limit(
249 &self,
250 info: &FunctionInfo,
251 function_name: &str,
252 auth: &AuthContext,
253 request: &RequestMetadata,
254 ) -> Result<()> {
255 let (requests, per_secs) = match (info.rate_limit_requests, info.rate_limit_per_secs) {
257 (Some(r), Some(p)) => (r, p),
258 _ => return Ok(()),
259 };
260
261 let key_str = info.rate_limit_key.unwrap_or("user");
263 let key_type: RateLimitKey = match key_str.parse() {
264 Ok(k) => k,
265 Err(_) => {
266 tracing::warn!(
267 function = %function_name,
268 key = %key_str,
269 "Invalid rate limit key, falling back to 'user'"
270 );
271 RateLimitKey::default()
272 }
273 };
274
275 let config =
276 RateLimitConfig::new(requests, Duration::from_secs(per_secs)).with_key(key_type);
277
278 let bucket_key = self
280 .rate_limiter
281 .build_key(key_type, function_name, auth, request);
282
283 self.rate_limiter.enforce(&bucket_key, &config).await?;
285
286 Ok(())
287 }
288
289 pub fn get_function_kind(&self, function_name: &str) -> Option<FunctionKind> {
291 self.registry.get(function_name).map(|e| e.kind())
292 }
293
294 pub fn has_function(&self, function_name: &str) -> bool {
296 self.registry.get(function_name).is_some()
297 }
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a public query descriptor alongside an unauthenticated caller.
    // NOTE(review): this only asserts the `is_public` flag it just set; it never
    // calls `FunctionRouter::check_auth`, so the router's auth policy is not
    // actually exercised here.
    #[test]
    fn test_check_auth_public() {
        let info = FunctionInfo {
            name: "test",
            description: None,
            kind: FunctionKind::Query,
            is_public: true,
            required_role: None,
            transactional: false,
            timeout: None,
            cache_ttl: None,
            log_level: None,
            table_dependencies: &[],
            rate_limit_key: None,
            rate_limit_requests: None,
            rate_limit_per_secs: None,
        };

        let _anonymous = AuthContext::unauthenticated();

        assert!(info.is_public);
    }
}