1use std::sync::Arc;
2use std::time::Duration;
3
4use forge_core::{
5 ActionContext, AuthContext, ForgeError, FunctionInfo, FunctionKind, JobDispatch,
6 MutationContext, QueryContext, RequestMetadata, Result, WorkflowDispatch,
7 rate_limit::{RateLimitConfig, RateLimitKey},
8};
9use serde_json::Value;
10
11use super::cache::QueryCache;
12use super::registry::{FunctionEntry, FunctionRegistry};
13use crate::rate_limit::RateLimiter;
14
15pub enum RouteResult {
17 Query(Value),
19 Mutation(Value),
21 Action(Value),
23}
24
25pub struct FunctionRouter {
27 registry: Arc<FunctionRegistry>,
28 db_pool: sqlx::PgPool,
29 http_client: reqwest::Client,
30 job_dispatcher: Option<Arc<dyn JobDispatch>>,
31 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
32 rate_limiter: RateLimiter,
33 query_cache: QueryCache,
34}
35
36impl FunctionRouter {
37 pub fn new(registry: Arc<FunctionRegistry>, db_pool: sqlx::PgPool) -> Self {
39 let rate_limiter = RateLimiter::new(db_pool.clone());
40 Self {
41 registry,
42 db_pool,
43 http_client: reqwest::Client::new(),
44 job_dispatcher: None,
45 workflow_dispatcher: None,
46 rate_limiter,
47 query_cache: QueryCache::new(),
48 }
49 }
50
51 pub fn with_http_client(
53 registry: Arc<FunctionRegistry>,
54 db_pool: sqlx::PgPool,
55 http_client: reqwest::Client,
56 ) -> Self {
57 let rate_limiter = RateLimiter::new(db_pool.clone());
58 Self {
59 registry,
60 db_pool,
61 http_client,
62 job_dispatcher: None,
63 workflow_dispatcher: None,
64 rate_limiter,
65 query_cache: QueryCache::new(),
66 }
67 }
68
69 pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
71 self.job_dispatcher = Some(dispatcher);
72 self
73 }
74
75 pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
77 self.workflow_dispatcher = Some(dispatcher);
78 self
79 }
80
81 pub async fn route(
83 &self,
84 function_name: &str,
85 args: Value,
86 auth: AuthContext,
87 request: RequestMetadata,
88 ) -> Result<RouteResult> {
89 let entry = self.registry.get(function_name).ok_or_else(|| {
90 ForgeError::NotFound(format!("Function '{}' not found", function_name))
91 })?;
92
93 self.check_auth(entry.info(), &auth)?;
95
96 self.check_rate_limit(entry.info(), function_name, &auth, &request)
98 .await?;
99
100 match entry {
101 FunctionEntry::Query { handler, info, .. } => {
102 if let Some(ttl) = info.cache_ttl {
104 if let Some(cached) = self.query_cache.get(function_name, &args) {
105 return Ok(RouteResult::Query(cached));
106 }
107
108 let ctx = QueryContext::new(self.db_pool.clone(), auth, request);
110 let result = handler(&ctx, args.clone()).await?;
111
112 self.query_cache.set(
113 function_name,
114 &args,
115 result.clone(),
116 Duration::from_secs(ttl),
117 );
118
119 Ok(RouteResult::Query(result))
120 } else {
121 let ctx = QueryContext::new(self.db_pool.clone(), auth, request);
122 let result = handler(&ctx, args).await?;
123 Ok(RouteResult::Query(result))
124 }
125 }
126 FunctionEntry::Mutation { handler, .. } => {
127 let ctx = MutationContext::with_dispatch(
128 self.db_pool.clone(),
129 auth,
130 request,
131 self.job_dispatcher.clone(),
132 self.workflow_dispatcher.clone(),
133 );
134 let result = handler(&ctx, args).await?;
135 Ok(RouteResult::Mutation(result))
136 }
137 FunctionEntry::Action { handler, .. } => {
138 let ctx = ActionContext::with_dispatch(
139 self.db_pool.clone(),
140 auth,
141 request,
142 self.http_client.clone(),
143 self.job_dispatcher.clone(),
144 self.workflow_dispatcher.clone(),
145 );
146 let result = handler(&ctx, args).await?;
147 Ok(RouteResult::Action(result))
148 }
149 }
150 }
151
152 fn check_auth(&self, info: &FunctionInfo, auth: &AuthContext) -> Result<()> {
154 if info.is_public {
156 return Ok(());
157 }
158
159 if info.requires_auth && !auth.is_authenticated() {
161 return Err(ForgeError::Unauthorized("Authentication required".into()));
162 }
163
164 if let Some(role) = info.required_role {
166 if !auth.has_role(role) {
167 return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
168 }
169 }
170
171 Ok(())
172 }
173
174 async fn check_rate_limit(
176 &self,
177 info: &FunctionInfo,
178 function_name: &str,
179 auth: &AuthContext,
180 request: &RequestMetadata,
181 ) -> Result<()> {
182 let (requests, per_secs) = match (info.rate_limit_requests, info.rate_limit_per_secs) {
184 (Some(r), Some(p)) => (r, p),
185 _ => return Ok(()),
186 };
187
188 let key_type: RateLimitKey = info
190 .rate_limit_key
191 .unwrap_or("user")
192 .parse()
193 .unwrap_or_default();
194
195 let config =
196 RateLimitConfig::new(requests, Duration::from_secs(per_secs)).with_key(key_type);
197
198 let bucket_key = self
200 .rate_limiter
201 .build_key(key_type, function_name, auth, request);
202
203 self.rate_limiter.enforce(&bucket_key, &config).await?;
205
206 Ok(())
207 }
208
209 pub fn get_function_kind(&self, function_name: &str) -> Option<FunctionKind> {
211 self.registry.get(function_name).map(|e| e.kind())
212 }
213
214 pub fn has_function(&self, function_name: &str) -> bool {
216 self.registry.get(function_name).is_some()
217 }
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): this test only inspects its own fixture. It never calls `check_auth`,
    // because that is a method on `FunctionRouter`, which needs a registry and a database
    // pool to construct.
    #[test]
    fn test_check_auth_public() {
        // Caller without credentials, which is what a public function must tolerate.
        let _unauthenticated = AuthContext::unauthenticated();

        // Fixture: a public query with no auth, role, caching, timeout or rate limit.
        let public_query = FunctionInfo {
            name: "test",
            description: None,
            kind: FunctionKind::Query,
            requires_auth: false,
            required_role: None,
            is_public: true,
            cache_ttl: None,
            timeout: None,
            rate_limit_requests: None,
            rate_limit_per_secs: None,
            rate_limit_key: None,
        };

        assert!(public_query.is_public);
    }
}