1use std::sync::Arc;
2use std::time::Duration;
3
4use forge_core::{
5 AuthContext, ForgeError, FunctionInfo, FunctionKind, JobDispatch, MutationContext,
6 QueryContext, RequestMetadata, Result, WorkflowDispatch,
7 rate_limit::{RateLimitConfig, RateLimitKey},
8};
9use serde_json::Value;
10
11use super::cache::QueryCache;
12use super::registry::{FunctionEntry, FunctionRegistry};
13use crate::rate_limit::RateLimiter;
14
15pub enum RouteResult {
17 Query(Value),
19 Mutation(Value),
21 Job(Value),
23 Workflow(Value),
25}
26
27pub struct FunctionRouter {
29 registry: Arc<FunctionRegistry>,
30 db_pool: sqlx::PgPool,
31 http_client: reqwest::Client,
32 job_dispatcher: Option<Arc<dyn JobDispatch>>,
33 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
34 rate_limiter: RateLimiter,
35 query_cache: QueryCache,
36}
37
38impl FunctionRouter {
39 pub fn new(registry: Arc<FunctionRegistry>, db_pool: sqlx::PgPool) -> Self {
41 let rate_limiter = RateLimiter::new(db_pool.clone());
42 Self {
43 registry,
44 db_pool,
45 http_client: reqwest::Client::new(),
46 job_dispatcher: None,
47 workflow_dispatcher: None,
48 rate_limiter,
49 query_cache: QueryCache::new(),
50 }
51 }
52
53 pub fn with_http_client(
55 registry: Arc<FunctionRegistry>,
56 db_pool: sqlx::PgPool,
57 http_client: reqwest::Client,
58 ) -> Self {
59 let rate_limiter = RateLimiter::new(db_pool.clone());
60 Self {
61 registry,
62 db_pool,
63 http_client,
64 job_dispatcher: None,
65 workflow_dispatcher: None,
66 rate_limiter,
67 query_cache: QueryCache::new(),
68 }
69 }
70
71 pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
73 self.job_dispatcher = Some(dispatcher);
74 self
75 }
76
77 pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
79 self.workflow_dispatcher = Some(dispatcher);
80 self
81 }
82
83 pub async fn route(
85 &self,
86 function_name: &str,
87 args: Value,
88 auth: AuthContext,
89 request: RequestMetadata,
90 ) -> Result<RouteResult> {
91 if let Some(entry) = self.registry.get(function_name) {
93 self.check_auth(entry.info(), &auth)?;
95
96 self.check_rate_limit(entry.info(), function_name, &auth, &request)
98 .await?;
99
100 return match entry {
101 FunctionEntry::Query { handler, info, .. } => {
102 if let Some(ttl) = info.cache_ttl {
103 if let Some(cached) = self.query_cache.get(function_name, &args) {
104 return Ok(RouteResult::Query(cached));
105 }
106
107 let ctx = QueryContext::new(self.db_pool.clone(), auth, request);
109 let result = handler(&ctx, args.clone()).await?;
110
111 self.query_cache.set(
112 function_name,
113 &args,
114 result.clone(),
115 Duration::from_secs(ttl),
116 );
117
118 Ok(RouteResult::Query(result))
119 } else {
120 let ctx = QueryContext::new(self.db_pool.clone(), auth, request);
121 let result = handler(&ctx, args).await?;
122 Ok(RouteResult::Query(result))
123 }
124 }
125 FunctionEntry::Mutation { handler, .. } => {
126 let ctx = MutationContext::with_dispatch(
127 self.db_pool.clone(),
128 auth,
129 request,
130 self.http_client.clone(),
131 self.job_dispatcher.clone(),
132 self.workflow_dispatcher.clone(),
133 );
134 let result = handler(&ctx, args).await?;
135 Ok(RouteResult::Mutation(result))
136 }
137 };
138 }
139
140 if let Some(ref job_dispatcher) = self.job_dispatcher {
142 if let Some(job_info) = job_dispatcher.get_info(function_name) {
143 self.check_job_auth(&job_info, &auth)?;
144 match job_dispatcher
145 .dispatch_by_name(function_name, args.clone())
146 .await
147 {
148 Ok(job_id) => {
149 return Ok(RouteResult::Job(serde_json::json!({ "job_id": job_id })));
150 }
151 Err(ForgeError::NotFound(_)) => {}
152 Err(e) => return Err(e),
153 }
154 }
155 }
156
157 if let Some(ref workflow_dispatcher) = self.workflow_dispatcher {
159 if let Some(workflow_info) = workflow_dispatcher.get_info(function_name) {
160 self.check_workflow_auth(&workflow_info, &auth)?;
161 match workflow_dispatcher
162 .start_by_name(function_name, args.clone())
163 .await
164 {
165 Ok(workflow_id) => {
166 return Ok(RouteResult::Workflow(
167 serde_json::json!({ "workflow_id": workflow_id }),
168 ));
169 }
170 Err(ForgeError::NotFound(_)) => {}
171 Err(e) => return Err(e),
172 }
173 }
174 }
175
176 Err(ForgeError::NotFound(format!(
178 "Function '{}' not found",
179 function_name
180 )))
181 }
182
183 fn check_auth(&self, info: &FunctionInfo, auth: &AuthContext) -> Result<()> {
185 if info.is_public {
187 return Ok(());
188 }
189
190 if let Some(role) = info.required_role {
192 if !auth.is_authenticated() {
193 return Err(ForgeError::Unauthorized("Authentication required".into()));
194 }
195 if !auth.has_role(role) {
196 return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
197 }
198 }
199
200 Ok(())
201 }
202
203 fn check_job_auth(&self, info: &forge_core::job::JobInfo, auth: &AuthContext) -> Result<()> {
205 if info.is_public {
206 return Ok(());
207 }
208
209 if let Some(role) = info.required_role {
210 if !auth.is_authenticated() {
211 return Err(ForgeError::Unauthorized("Authentication required".into()));
212 }
213 if !auth.has_role(role) {
214 return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
215 }
216 }
217
218 Ok(())
219 }
220
221 fn check_workflow_auth(
223 &self,
224 info: &forge_core::workflow::WorkflowInfo,
225 auth: &AuthContext,
226 ) -> Result<()> {
227 if info.is_public {
228 return Ok(());
229 }
230
231 if let Some(role) = info.required_role {
232 if !auth.is_authenticated() {
233 return Err(ForgeError::Unauthorized("Authentication required".into()));
234 }
235 if !auth.has_role(role) {
236 return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
237 }
238 }
239
240 Ok(())
241 }
242
243 async fn check_rate_limit(
245 &self,
246 info: &FunctionInfo,
247 function_name: &str,
248 auth: &AuthContext,
249 request: &RequestMetadata,
250 ) -> Result<()> {
251 let (requests, per_secs) = match (info.rate_limit_requests, info.rate_limit_per_secs) {
253 (Some(r), Some(p)) => (r, p),
254 _ => return Ok(()),
255 };
256
257 let key_str = info.rate_limit_key.unwrap_or("user");
259 let key_type: RateLimitKey = match key_str.parse() {
260 Ok(k) => k,
261 Err(_) => {
262 tracing::warn!(
263 function = %function_name,
264 key = %key_str,
265 "Invalid rate limit key, falling back to 'user'"
266 );
267 RateLimitKey::default()
268 }
269 };
270
271 let config =
272 RateLimitConfig::new(requests, Duration::from_secs(per_secs)).with_key(key_type);
273
274 let bucket_key = self
276 .rate_limiter
277 .build_key(key_type, function_name, auth, request);
278
279 self.rate_limiter.enforce(&bucket_key, &config).await?;
281
282 Ok(())
283 }
284
285 pub fn get_function_kind(&self, function_name: &str) -> Option<FunctionKind> {
287 self.registry.get(function_name).map(|e| e.kind())
288 }
289
290 pub fn has_function(&self, function_name: &str) -> bool {
292 self.registry.get(function_name).is_some()
293 }
294}
295
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_check_auth_public() {
        // A query marked public, with every optional knob left unset.
        let info = FunctionInfo {
            name: "test",
            description: None,
            kind: FunctionKind::Query,
            is_public: true,
            required_role: None,
            cache_ttl: None,
            timeout: None,
            log_level: None,
            rate_limit_requests: None,
            rate_limit_per_secs: None,
            rate_limit_key: None,
            table_dependencies: &[],
            transactional: false,
        };

        // NOTE(review): this does not call `FunctionRouter::check_auth`, because
        // that needs a router and therefore a `PgPool`. It only checks that the
        // public flag is set on the fixture.
        let _anonymous = AuthContext::unauthenticated();

        assert!(info.is_public);
    }
}